diff --git a/src/util/decimal.rs b/src/util/decimal.rs index 48d89d3c..6436f783 100644 --- a/src/util/decimal.rs +++ b/src/util/decimal.rs @@ -33,6 +33,13 @@ pub struct Decimal { exponent: usize, } +/// Unsigned fixed-point decimal type +#[derive(Copy, Clone, Debug, Eq, Ord)] +pub struct UDecimal { + mantissa: u64, + exponent: usize, +} + impl PartialEq for Decimal { fn eq(&self, other: &Decimal) -> bool { use std::cmp::max; @@ -166,6 +173,123 @@ impl de::Deserialize for Decimal { } } + +impl PartialEq for UDecimal { + fn eq(&self, other: &UDecimal) -> bool { + use std::cmp::max; + let exp = max(self.exponent(), other.exponent()); + self.integer_value(exp) == other.integer_value(exp) + } +} + +impl PartialOrd for UDecimal { + fn partial_cmp(&self, other: &UDecimal) -> Option<::std::cmp::Ordering> { + use std::cmp::max; + let exp = max(self.exponent(), other.exponent()); + self.integer_value(exp).partial_cmp(&other.integer_value(exp)) + } +} + +impl fmt::Display for UDecimal { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let ten = 10u64.pow(self.exponent as u32); + let int_part = self.mantissa / ten; + let dec_part = self.mantissa % ten; + write!(f, "{}.{:02$}", int_part, dec_part, self.exponent) + } +} + +impl ops::Add for UDecimal { + type Output = UDecimal; + + #[inline] + fn add(self, other: UDecimal) -> UDecimal { + if self.exponent > other.exponent { + UDecimal { + mantissa: other.mantissa * 10u64.pow((self.exponent - other.exponent) as u32) + self.mantissa, + exponent: self.exponent + } + } else { + UDecimal { + mantissa: self.mantissa * 10u64.pow((other.exponent - self.exponent) as u32) + other.mantissa, + exponent: other.exponent + } + } + } +} + +impl UDecimal { + /// Creates a new Decimal + pub fn new(mantissa: u64, exponent: usize) -> UDecimal { + UDecimal { + mantissa: mantissa, + exponent: exponent + } + } + + /// Returns the mantissa + #[inline] + pub fn mantissa(&self) -> u64 { self.mantissa } + /// Returns the exponent + #[inline] + pub fn exponent(&self) -> usize { self.exponent } + + /// Get the decimal's value in an integer type, by multiplying + /// by some power of ten to ensure the returned value is 10 ** + /// `exponent` types the actual value. + pub fn integer_value(&self, exponent: usize) -> u64 { + if exponent < self.exponent { + self.mantissa / 10u64.pow((self.exponent - exponent) as u32) + } else { + self.mantissa * 10u64.pow((exponent - self.exponent) as u32) + } + } +} + +impl ser::Serialize for UDecimal { + // Serialize through strason since it will not lose precision (when serializing + // to strason itself, the value will be passed through; otherwise it will be + // encoded as a string) + fn serialize(&self, s: &mut S) -> Result<(), S::Error> { + let json = Json::from_str(&self.to_string()).unwrap(); + ser::Serialize::serialize(&json, s) + } +} + +impl de::Deserialize for UDecimal { + // Deserialize through strason for the same reason as in `Serialize` + fn deserialize(d: &mut D) -> Result { + let json: Json = try!(de::Deserialize::deserialize(d)); + match json.num() { + Some(s) => { + // We know this will be a well-formed Json number, so we can + // be pretty lax about parsing + let mut past_dec = false; + let mut exponent = 0; + let mut mantissa = 0u64; + + for b in s.as_bytes() { + match *b { + b'0'...b'9' => { + mantissa = 10 * mantissa + (b - b'0') as u64; + if past_dec { exponent += 1; } + } + b'.' => { past_dec = true; } + _ => { /* whitespace or something, just ignore it */ } + } + } + Ok(UDecimal { + mantissa: mantissa, + exponent: exponent, + }) + } + None => Err(de::Error::syntax("expected decimal, got non-numeric")) + } + } +} + + + #[cfg(test)] mod tests { use super::*; @@ -186,6 +310,20 @@ mod tests { assert_eq!(d.integer_value(6), 1234567800); assert_eq!(d.integer_value(7), 12345678000); assert_eq!(d.integer_value(8), 123456780000); + + let u = UDecimal::new(12345678, 4); + assert_eq!(u.mantissa(), 12345678); + assert_eq!(u.exponent(), 4); + + assert_eq!(u.integer_value(0), 1234); + assert_eq!(u.integer_value(1), 12345); + assert_eq!(u.integer_value(2), 123456); + assert_eq!(u.integer_value(3), 1234567); + assert_eq!(u.integer_value(4), 12345678); + assert_eq!(u.integer_value(5), 123456780); + assert_eq!(u.integer_value(6), 1234567800); + assert_eq!(u.integer_value(7), 12345678000); + assert_eq!(u.integer_value(8), 123456780000); } macro_rules! deserialize_round_trip( @@ -195,7 +333,10 @@ mod tests { assert_eq!(encoded, Json::from_reader(&$s[..]).unwrap()); assert_eq!(encoded.to_bytes(), &$s[..]); - let decoded: Decimal = encoded.into_deserialize().unwrap(); + // hack to force type inference + let mut decoded_res = encoded.into_deserialize(); + if false { decoded_res = Ok($dec); } + let decoded = decoded_res.unwrap(); assert_eq!(decoded, d); }) ); @@ -203,15 +344,20 @@ mod tests { #[test] fn deserialize() { deserialize_round_trip!(Decimal::new(0, 0), b"0.0"); + deserialize_round_trip!(UDecimal::new(0, 0), b"0.0"); deserialize_round_trip!(Decimal::new(123456789001, 8), b"1234.56789001"); + deserialize_round_trip!(UDecimal::new(123456789001, 8), b"1234.56789001"); deserialize_round_trip!(Decimal::new(-123456789001, 8), b"-1234.56789001"); deserialize_round_trip!(Decimal::new(123456789001, 1), b"12345678900.1"); + deserialize_round_trip!(UDecimal::new(123456789001, 1), b"12345678900.1"); deserialize_round_trip!(Decimal::new(-123456789001, 1), b"-12345678900.1"); deserialize_round_trip!(Decimal::new(123456789001, 0), b"123456789001.0"); + deserialize_round_trip!(UDecimal::new(123456789001, 0), b"123456789001.0"); deserialize_round_trip!(Decimal::new(-123456789001, 0), b"-123456789001.0"); deserialize_round_trip!(Decimal::new(123400000001, 8), b"1234.00000001"); + deserialize_round_trip!(UDecimal::new(123400000001, 8), b"1234.00000001"); deserialize_round_trip!(Decimal::new(-123400000001, 8), b"-1234.00000001"); } @@ -240,6 +386,9 @@ mod tests { let d2 = Decimal::new(-2, 2); // -0.02 let d3 = Decimal::new(3, 0); // 3.0 let d4 = Decimal::new(0, 5); // 0.00000 + let u1 = UDecimal::new(5, 1); // 0.5 + let u3 = UDecimal::new(3, 0); // 3.0 + let u4 = UDecimal::new(0, 5); // 0.00000 assert!(d1.nonnegative()); assert!(!d2.nonnegative()); @@ -249,14 +398,17 @@ mod tests { assert_eq!(d1 + d2, Decimal::new(48, 2)); assert_eq!(d1 - d2, Decimal::new(52, 2)); assert_eq!(d1 + d3, Decimal::new(35, 1)); + assert_eq!(u1 + u3, UDecimal::new(35, 1)); assert_eq!(d1 - d3, Decimal::new(-25, 1)); assert_eq!(d2 + d3, Decimal::new(298, 2)); assert_eq!(d2 - d3, Decimal::new(-302, 2)); assert_eq!(d1 + d4, d1); + assert_eq!(u1 + u4, u1); assert_eq!(d1 - d4, d1); assert_eq!(d1 + d4, d1 - d4); assert_eq!(d4 + d4, d4); + assert_eq!(u4 + u4, u4); } #[test] @@ -266,10 +418,20 @@ mod tests { let dec: Decimal = json.into_deserialize().unwrap(); assert_eq!(dec, Decimal::new(980000, 8)); + let json = Json::from_str("0.00980000").unwrap(); + assert_eq!(json.to_bytes(), b"0.00980000"); + let dec: UDecimal = json.into_deserialize().unwrap(); + assert_eq!(dec, UDecimal::new(980000, 8)); + let json = Json::from_str("0.00980").unwrap(); assert_eq!(json.to_bytes(), b"0.00980"); let dec: Decimal = json.into_deserialize().unwrap(); assert_eq!(dec, Decimal::new(98000, 7)); + + let json = Json::from_str("0.00980").unwrap(); + assert_eq!(json.to_bytes(), b"0.00980"); + let dec: UDecimal = json.into_deserialize().unwrap(); + assert_eq!(dec, UDecimal::new(98000, 7)); } }