Add unsigned Decimal type

This commit is contained in:
Andrew Poelstra 2015-12-05 03:12:46 -06:00
parent 5c69d44397
commit 45ef239a34
1 changed files with 163 additions and 1 deletions

View File

@ -33,6 +33,13 @@ pub struct Decimal {
exponent: usize, exponent: usize,
} }
/// Unsigned fixed-point decimal type
#[derive(Copy, Clone, Debug, Eq, Ord)]
pub struct UDecimal {
mantissa: u64,
exponent: usize,
}
impl PartialEq<Decimal> for Decimal { impl PartialEq<Decimal> for Decimal {
fn eq(&self, other: &Decimal) -> bool { fn eq(&self, other: &Decimal) -> bool {
use std::cmp::max; use std::cmp::max;
@ -166,6 +173,123 @@ impl de::Deserialize for Decimal {
} }
} }
impl PartialEq<UDecimal> for UDecimal {
fn eq(&self, other: &UDecimal) -> bool {
use std::cmp::max;
let exp = max(self.exponent(), other.exponent());
self.integer_value(exp) == other.integer_value(exp)
}
}
impl PartialOrd<UDecimal> for UDecimal {
fn partial_cmp(&self, other: &UDecimal) -> Option<::std::cmp::Ordering> {
use std::cmp::max;
let exp = max(self.exponent(), other.exponent());
self.integer_value(exp).partial_cmp(&other.integer_value(exp))
}
}
impl fmt::Display for UDecimal {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let ten = 10u64.pow(self.exponent as u32);
let int_part = self.mantissa / ten;
let dec_part = self.mantissa % ten;
write!(f, "{}.{:02$}", int_part, dec_part, self.exponent)
}
}
impl ops::Add for UDecimal {
type Output = UDecimal;
#[inline]
fn add(self, other: UDecimal) -> UDecimal {
if self.exponent > other.exponent {
UDecimal {
mantissa: other.mantissa * 10u64.pow((self.exponent - other.exponent) as u32) + self.mantissa,
exponent: self.exponent
}
} else {
UDecimal {
mantissa: self.mantissa * 10u64.pow((other.exponent - self.exponent) as u32) + other.mantissa,
exponent: other.exponent
}
}
}
}
impl UDecimal {
/// Creates a new Decimal
pub fn new(mantissa: u64, exponent: usize) -> UDecimal {
UDecimal {
mantissa: mantissa,
exponent: exponent
}
}
/// Returns the mantissa
#[inline]
pub fn mantissa(&self) -> u64 { self.mantissa }
/// Returns the exponent
#[inline]
pub fn exponent(&self) -> usize { self.exponent }
/// Get the decimal's value in an integer type, by multiplying
/// by some power of ten to ensure the returned value is 10 **
/// `exponent` types the actual value.
pub fn integer_value(&self, exponent: usize) -> u64 {
if exponent < self.exponent {
self.mantissa / 10u64.pow((self.exponent - exponent) as u32)
} else {
self.mantissa * 10u64.pow((exponent - self.exponent) as u32)
}
}
}
impl ser::Serialize for UDecimal {
// Serialize through strason since it will not lose precision (when serializing
// to strason itself, the value will be passed through; otherwise it will be
// encoded as a string)
fn serialize<S: ser::Serializer>(&self, s: &mut S) -> Result<(), S::Error> {
let json = Json::from_str(&self.to_string()).unwrap();
ser::Serialize::serialize(&json, s)
}
}
impl de::Deserialize for UDecimal {
// Deserialize through strason for the same reason as in `Serialize`
fn deserialize<D: de::Deserializer>(d: &mut D) -> Result<UDecimal, D::Error> {
let json: Json = try!(de::Deserialize::deserialize(d));
match json.num() {
Some(s) => {
// We know this will be a well-formed Json number, so we can
// be pretty lax about parsing
let mut past_dec = false;
let mut exponent = 0;
let mut mantissa = 0u64;
for b in s.as_bytes() {
match *b {
b'0'...b'9' => {
mantissa = 10 * mantissa + (b - b'0') as u64;
if past_dec { exponent += 1; }
}
b'.' => { past_dec = true; }
_ => { /* whitespace or something, just ignore it */ }
}
}
Ok(UDecimal {
mantissa: mantissa,
exponent: exponent,
})
}
None => Err(de::Error::syntax("expected decimal, got non-numeric"))
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -186,6 +310,20 @@ mod tests {
assert_eq!(d.integer_value(6), 1234567800); assert_eq!(d.integer_value(6), 1234567800);
assert_eq!(d.integer_value(7), 12345678000); assert_eq!(d.integer_value(7), 12345678000);
assert_eq!(d.integer_value(8), 123456780000); assert_eq!(d.integer_value(8), 123456780000);
let u = UDecimal::new(12345678, 4);
assert_eq!(u.mantissa(), 12345678);
assert_eq!(u.exponent(), 4);
assert_eq!(u.integer_value(0), 1234);
assert_eq!(u.integer_value(1), 12345);
assert_eq!(u.integer_value(2), 123456);
assert_eq!(u.integer_value(3), 1234567);
assert_eq!(u.integer_value(4), 12345678);
assert_eq!(u.integer_value(5), 123456780);
assert_eq!(u.integer_value(6), 1234567800);
assert_eq!(u.integer_value(7), 12345678000);
assert_eq!(u.integer_value(8), 123456780000);
} }
macro_rules! deserialize_round_trip( macro_rules! deserialize_round_trip(
@ -195,7 +333,10 @@ mod tests {
assert_eq!(encoded, Json::from_reader(&$s[..]).unwrap()); assert_eq!(encoded, Json::from_reader(&$s[..]).unwrap());
assert_eq!(encoded.to_bytes(), &$s[..]); assert_eq!(encoded.to_bytes(), &$s[..]);
let decoded: Decimal = encoded.into_deserialize().unwrap(); // hack to force type inference
let mut decoded_res = encoded.into_deserialize();
if false { decoded_res = Ok($dec); }
let decoded = decoded_res.unwrap();
assert_eq!(decoded, d); assert_eq!(decoded, d);
}) })
); );
@ -203,15 +344,20 @@ mod tests {
#[test] #[test]
fn deserialize() { fn deserialize() {
deserialize_round_trip!(Decimal::new(0, 0), b"0.0"); deserialize_round_trip!(Decimal::new(0, 0), b"0.0");
deserialize_round_trip!(UDecimal::new(0, 0), b"0.0");
deserialize_round_trip!(Decimal::new(123456789001, 8), b"1234.56789001"); deserialize_round_trip!(Decimal::new(123456789001, 8), b"1234.56789001");
deserialize_round_trip!(UDecimal::new(123456789001, 8), b"1234.56789001");
deserialize_round_trip!(Decimal::new(-123456789001, 8), b"-1234.56789001"); deserialize_round_trip!(Decimal::new(-123456789001, 8), b"-1234.56789001");
deserialize_round_trip!(Decimal::new(123456789001, 1), b"12345678900.1"); deserialize_round_trip!(Decimal::new(123456789001, 1), b"12345678900.1");
deserialize_round_trip!(UDecimal::new(123456789001, 1), b"12345678900.1");
deserialize_round_trip!(Decimal::new(-123456789001, 1), b"-12345678900.1"); deserialize_round_trip!(Decimal::new(-123456789001, 1), b"-12345678900.1");
deserialize_round_trip!(Decimal::new(123456789001, 0), b"123456789001.0"); deserialize_round_trip!(Decimal::new(123456789001, 0), b"123456789001.0");
deserialize_round_trip!(UDecimal::new(123456789001, 0), b"123456789001.0");
deserialize_round_trip!(Decimal::new(-123456789001, 0), b"-123456789001.0"); deserialize_round_trip!(Decimal::new(-123456789001, 0), b"-123456789001.0");
deserialize_round_trip!(Decimal::new(123400000001, 8), b"1234.00000001"); deserialize_round_trip!(Decimal::new(123400000001, 8), b"1234.00000001");
deserialize_round_trip!(UDecimal::new(123400000001, 8), b"1234.00000001");
deserialize_round_trip!(Decimal::new(-123400000001, 8), b"-1234.00000001"); deserialize_round_trip!(Decimal::new(-123400000001, 8), b"-1234.00000001");
} }
@ -240,6 +386,9 @@ mod tests {
let d2 = Decimal::new(-2, 2); // -0.02 let d2 = Decimal::new(-2, 2); // -0.02
let d3 = Decimal::new(3, 0); // 3.0 let d3 = Decimal::new(3, 0); // 3.0
let d4 = Decimal::new(0, 5); // 0.00000 let d4 = Decimal::new(0, 5); // 0.00000
let u1 = UDecimal::new(5, 1); // 0.5
let u3 = UDecimal::new(3, 0); // 3.0
let u4 = UDecimal::new(0, 5); // 0.00000
assert!(d1.nonnegative()); assert!(d1.nonnegative());
assert!(!d2.nonnegative()); assert!(!d2.nonnegative());
@ -249,14 +398,17 @@ mod tests {
assert_eq!(d1 + d2, Decimal::new(48, 2)); assert_eq!(d1 + d2, Decimal::new(48, 2));
assert_eq!(d1 - d2, Decimal::new(52, 2)); assert_eq!(d1 - d2, Decimal::new(52, 2));
assert_eq!(d1 + d3, Decimal::new(35, 1)); assert_eq!(d1 + d3, Decimal::new(35, 1));
assert_eq!(u1 + u3, UDecimal::new(35, 1));
assert_eq!(d1 - d3, Decimal::new(-25, 1)); assert_eq!(d1 - d3, Decimal::new(-25, 1));
assert_eq!(d2 + d3, Decimal::new(298, 2)); assert_eq!(d2 + d3, Decimal::new(298, 2));
assert_eq!(d2 - d3, Decimal::new(-302, 2)); assert_eq!(d2 - d3, Decimal::new(-302, 2));
assert_eq!(d1 + d4, d1); assert_eq!(d1 + d4, d1);
assert_eq!(u1 + u4, u1);
assert_eq!(d1 - d4, d1); assert_eq!(d1 - d4, d1);
assert_eq!(d1 + d4, d1 - d4); assert_eq!(d1 + d4, d1 - d4);
assert_eq!(d4 + d4, d4); assert_eq!(d4 + d4, d4);
assert_eq!(u4 + u4, u4);
} }
#[test] #[test]
@ -266,10 +418,20 @@ mod tests {
let dec: Decimal = json.into_deserialize().unwrap(); let dec: Decimal = json.into_deserialize().unwrap();
assert_eq!(dec, Decimal::new(980000, 8)); assert_eq!(dec, Decimal::new(980000, 8));
let json = Json::from_str("0.00980000").unwrap();
assert_eq!(json.to_bytes(), b"0.00980000");
let dec: UDecimal = json.into_deserialize().unwrap();
assert_eq!(dec, UDecimal::new(980000, 8));
let json = Json::from_str("0.00980").unwrap(); let json = Json::from_str("0.00980").unwrap();
assert_eq!(json.to_bytes(), b"0.00980"); assert_eq!(json.to_bytes(), b"0.00980");
let dec: Decimal = json.into_deserialize().unwrap(); let dec: Decimal = json.into_deserialize().unwrap();
assert_eq!(dec, Decimal::new(98000, 7)); assert_eq!(dec, Decimal::new(98000, 7));
let json = Json::from_str("0.00980").unwrap();
assert_eq!(json.to_bytes(), b"0.00980");
let dec: UDecimal = json.into_deserialize().unwrap();
assert_eq!(dec, UDecimal::new(98000, 7));
} }
} }