diff --git a/units/src/amount/result.rs b/units/src/amount/result.rs index e229a02c7..ae1c6e12d 100644 --- a/units/src/amount/result.rs +++ b/units/src/amount/result.rs @@ -161,6 +161,16 @@ crate::internal_macros::impl_op_for_references! { fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) } } + impl ops::Mul for u64 { + type Output = NumOpResult; + + fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error() } + } + impl ops::Mul> for u64 { + type Output = NumOpResult; + + fn mul(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|rhs| self * rhs) } + } impl ops::Div for Amount { type Output = NumOpResult; @@ -172,6 +182,11 @@ crate::internal_macros::impl_op_for_references! { fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) } } + impl ops::Div for Amount { + type Output = u64; + + fn div(self, rhs: Amount) -> Self::Output { self.to_sat() / rhs.to_sat() } + } impl ops::Rem for Amount { type Output = NumOpResult; @@ -221,6 +236,16 @@ crate::internal_macros::impl_op_for_references! { fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) } } + impl ops::Mul for i64 { + type Output = NumOpResult; + + fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error() } + } + impl ops::Mul> for i64 { + type Output = NumOpResult; + + fn mul(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|rhs| self * rhs) } + } impl ops::Div for SignedAmount { type Output = NumOpResult; @@ -232,6 +257,11 @@ crate::internal_macros::impl_op_for_references! { fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) } } + impl ops::Div for SignedAmount { + type Output = i64; + + fn div(self, rhs: SignedAmount) -> Self::Output { self.to_sat() / rhs.to_sat() } + } impl ops::Rem for SignedAmount { type Output = NumOpResult; diff --git a/units/src/amount/tests.rs b/units/src/amount/tests.rs index e0d506abb..d263623eb 100644 --- a/units/src/amount/tests.rs +++ b/units/src/amount/tests.rs @@ -1138,64 +1138,191 @@ fn trailing_zeros_for_amount() { } #[test] -#[allow(clippy::op_ref)] +fn add_sub_combos() { + // Checks lhs op rhs for all reference combos. + macro_rules! check_ref { + ($($lhs:ident $op:tt $rhs:ident = $ans:ident);* $(;)?) => { + $( + assert_eq!($lhs $op $rhs, $ans); + assert_eq!(&$lhs $op $rhs, $ans); + assert_eq!($lhs $op &$rhs, $ans); + assert_eq!(&$lhs $op &$rhs, $ans); + )* + } + } + + // Checks lhs op rhs for all amount and `NumOpResult` combos. + macro_rules! check_res { + ($($amount:ident, $op:tt, $lhs:literal, $rhs:literal, $ans:literal);* $(;)?) => { + $( + let amt = |sat| $amount::from_sat(sat); + + let sat_lhs = amt($lhs); + let sat_rhs = amt($rhs); + + let res_lhs = NumOpResult::from(sat_lhs); + let res_rhs = NumOpResult::from(sat_rhs); + + let ans = NumOpResult::from(amt($ans)); + + check_ref! { + sat_lhs $op sat_rhs = ans; + sat_lhs $op res_rhs = ans; + res_lhs $op sat_rhs = ans; + res_lhs $op res_rhs = ans; + } + )* + } + } + + // Checks lhs op rhs for both amount types. + macro_rules! check_op { + ($($lhs:literal $op:tt $rhs:literal = $ans:literal);* $(;)?) => { + $( + check_res!(Amount, $op, $lhs, $rhs, $ans); + check_res!(SignedAmount, $op, $lhs, $rhs, $ans); + )* + } + } + + // We do not currently support division involving `NumOpResult` and an amount type. + check_op! { + 307 + 461 = 768; + 461 - 307 = 154; + } +} + +#[test] fn unsigned_addition() { let sat = Amount::from_sat; - let one = sat(1); - let two = sat(2); - let three = sat(3); - - assert!((one + two) == three.into()); - assert!((one + two) == three.into()); - assert!((&one + two) == three.into()); - assert!((one + &two) == three.into()); - assert!((&one + &two) == three.into()); + assert_eq!(sat(0) + sat(0), NumOpResult::from(sat(0))); + assert_eq!(sat(0) + sat(307), NumOpResult::from(sat(307))); + assert_eq!(sat(307) + sat(0), NumOpResult::from(sat(307))); + assert_eq!(sat(307) + sat(461), NumOpResult::from(sat(768))); + assert_eq!(sat(0) + Amount::MAX_MONEY, NumOpResult::from(Amount::MAX_MONEY)); } #[test] -#[allow(clippy::op_ref)] -fn unsigned_subtract() { - let sat = Amount::from_sat; - - let one = sat(1); - let two = sat(2); - let three = sat(3); - - assert!(three - two == one.into()); - assert!(&three - two == one.into()); - assert!(three - &two == one.into()); - assert!(&three - &two == one.into()); -} - -#[test] -#[allow(clippy::op_ref)] fn signed_addition() { let ssat = SignedAmount::from_sat; - let one = ssat(1); - let two = ssat(2); - let three = ssat(3); + assert_eq!(ssat(0) + ssat(0), NumOpResult::from(ssat(0))); + assert_eq!(ssat(0) + ssat(307), NumOpResult::from(ssat(307))); + assert_eq!(ssat(307) + ssat(0), NumOpResult::from(ssat(307))); + assert_eq!(ssat(307) + ssat(461), NumOpResult::from(ssat(768))); + assert_eq!(ssat(0) + SignedAmount::MAX_MONEY, NumOpResult::from(SignedAmount::MAX_MONEY)); - assert!(one + two == three.into()); - assert!(&one + two == three.into()); - assert!(one + &two == three.into()); - assert!(&one + &two == three.into()); + assert_eq!(ssat(0) + ssat(-307), NumOpResult::from(ssat(-307))); + assert_eq!(ssat(-307) + ssat(0), NumOpResult::from(ssat(-307))); + assert_eq!(ssat(-307) + ssat(461), NumOpResult::from(ssat(154))); + assert_eq!(ssat(307) + ssat(-461), NumOpResult::from(ssat(-154))); + assert_eq!(ssat(-307) + ssat(-461), NumOpResult::from(ssat(-768))); + assert_eq!( + SignedAmount::MAX_MONEY + -SignedAmount::MAX_MONEY, + NumOpResult::from(SignedAmount::ZERO) + ); } #[test] -#[allow(clippy::op_ref)] -fn signed_subtract() { +fn unsigned_subtraction() { + let sat = Amount::from_sat; + + assert_eq!(sat(0) - sat(0), NumOpResult::from(sat(0))); + assert_eq!(sat(307) - sat(0), NumOpResult::from(sat(307))); + assert_eq!(sat(461) - sat(307), NumOpResult::from(sat(154))); +} + +#[test] +fn signed_subtraction() { let ssat = SignedAmount::from_sat; - let one = ssat(1); - let two = ssat(2); - let three = ssat(3); + assert_eq!(ssat(0) - ssat(0), NumOpResult::from(ssat(0))); + assert_eq!(ssat(0) - ssat(307), NumOpResult::from(ssat(-307))); + assert_eq!(ssat(307) - ssat(0), NumOpResult::from(ssat(307))); + assert_eq!(ssat(307) - ssat(461), NumOpResult::from(ssat(-154))); + assert_eq!(ssat(0) - SignedAmount::MAX_MONEY, NumOpResult::from(-SignedAmount::MAX_MONEY)); - assert!(three - two == one.into()); - assert!(&three - two == one.into()); - assert!(three - &two == one.into()); - assert!(&three - &two == one.into()); + assert_eq!(ssat(0) - ssat(-307), NumOpResult::from(ssat(307))); + assert_eq!(ssat(-307) - ssat(0), NumOpResult::from(ssat(-307))); + assert_eq!(ssat(-307) - ssat(461), NumOpResult::from(ssat(-768))); + assert_eq!(ssat(307) - ssat(-461), NumOpResult::from(ssat(768))); + assert_eq!(ssat(-307) - ssat(-461), NumOpResult::from(ssat(154))); +} + +#[test] +fn op_int_combos() { + let sat = Amount::from_sat; + let ssat = SignedAmount::from_sat; + + let res = |sat| NumOpResult::from(Amount::from_sat(sat)); + let sres = |ssat| NumOpResult::from(SignedAmount::from_sat(ssat)); + + assert_eq!(sat(23) * 31, res(713)); + assert_eq!(ssat(23) * 31, sres(713)); + assert_eq!(res(23) * 31, res(713)); + assert_eq!(sres(23) * 31, sres(713)); + + assert_eq!(31 * sat(23), res(713)); + assert_eq!(31 * ssat(23), sres(713)); + assert_eq!(31 * res(23), res(713)); + assert_eq!(31 * sres(23), sres(713)); + + // No remainder. + assert_eq!(sat(1897) / 7, res(271)); + assert_eq!(ssat(1897) / 7, sres(271)); + assert_eq!(res(1897) / 7, res(271)); + assert_eq!(sres(1897) / 7, sres(271)); + + // Truncation works as expected. + assert_eq!(sat(1901) / 7, res(271)); + assert_eq!(ssat(1901) / 7, sres(271)); + assert_eq!(res(1901) / 7, res(271)); + assert_eq!(sres(1901) / 7, sres(271)); + + // No remainder. + assert_eq!(sat(1897) % 7, res(0)); + assert_eq!(ssat(1897) % 7, sres(0)); + assert_eq!(res(1897) % 7, res(0)); + assert_eq!(sres(1897) % 7, sres(0)); + + // Remainder works as expected. + assert_eq!(sat(1901) % 7, res(4)); + assert_eq!(ssat(1901) % 7, sres(4)); + assert_eq!(res(1901) % 7, res(4)); + assert_eq!(sres(1901) % 7, sres(4)); +} + +#[test] +fn unsigned_amount_div_by_amount() { + let sat = Amount::from_sat; + + assert_eq!(sat(0) / sat(7), 0); + assert_eq!(sat(1897) / sat(7), 271); +} + +#[test] +#[should_panic(expected = "attempt to divide by zero")] +fn unsigned_amount_div_by_amount_zero() { + let _ = Amount::from_sat(1897) / Amount::ZERO; +} + +#[test] +fn signed_amount_div_by_amount() { + let ssat = SignedAmount::from_sat; + + assert_eq!(ssat(0) / ssat(7), 0); + + assert_eq!(ssat(1897) / ssat(7), 271); + assert_eq!(ssat(1897) / ssat(-7), -271); + assert_eq!(ssat(-1897) / ssat(7), -271); + assert_eq!(ssat(-1897) / ssat(-7), 271); +} + +#[test] +#[should_panic(expected = "attempt to divide by zero")] +fn signed_amount_div_by_amount_zero() { + let _ = SignedAmount::from_sat(1897) / SignedAmount::ZERO; } #[test] @@ -1300,6 +1427,15 @@ fn num_op_result_ops_integer() { } } check_op! { + // Operations on an amount type and an integer. + let _ = sat * 3_u64; // Explicit type for the benefit of the reader. + let _ = sat / 3; + let _ = sat % 3; + + let _ = ssat * 3_i64; // Explicit type for the benefit of the reader. + let _ = ssat / 3; + let _ = ssat % 3; + // Operations on a `NumOpResult` and integer. let _ = res * 3_u64; // Explicit type for the benefit of the reader. let _ = res / 3;