diff --git a/bitcoin/examples/taproot-psbt.rs b/bitcoin/examples/taproot-psbt.rs index 1d712ed1b..c312cf3d7 100644 --- a/bitcoin/examples/taproot-psbt.rs +++ b/bitcoin/examples/taproot-psbt.rs @@ -412,7 +412,8 @@ impl BenefactorWallet { taproot_spend_info.internal_key(), taproot_spend_info.merkle_root(), ); - let value = input_utxo.amount - ABSOLUTE_FEES; + let value = (input_utxo.amount - ABSOLUTE_FEES) + .expect("ABSOLUTE_FEES must be set below input amount"); // Spend a normal BIP86-like output as an input in our inheritance funding transaction let tx = generate_bip86_key_spend_tx( @@ -476,7 +477,7 @@ impl BenefactorWallet { let mut psbt = self.next_psbt.clone().expect("should have next_psbt"); let input = &mut psbt.inputs[0]; let input_value = input.witness_utxo.as_ref().unwrap().value; - let output_value = input_value - ABSOLUTE_FEES; + let output_value = (input_value - ABSOLUTE_FEES).into_result()?; // We use some other derivation path in this example for our inheritance protocol. The important thing is to ensure // that we use an unhardened path so we can make use of xpubs. @@ -649,7 +650,8 @@ impl BeneficiaryWallet { psbt.unsigned_tx.lock_time = lock_time; psbt.unsigned_tx.output = vec![TxOut { script_pubkey: to_address.script_pubkey(), - value: input_value - ABSOLUTE_FEES, + value: (input_value - ABSOLUTE_FEES) + .expect("ABSOLUTE_FEES must be set below input amount"), }]; psbt.outputs = vec![Output::default()]; let unsigned_tx = psbt.unsigned_tx.clone(); diff --git a/bitcoin/src/blockdata/transaction.rs b/bitcoin/src/blockdata/transaction.rs index 7d7bf7699..56094aa7c 100644 --- a/bitcoin/src/blockdata/transaction.rs +++ b/bitcoin/src/blockdata/transaction.rs @@ -1673,7 +1673,7 @@ mod tests { // 10 sat/kwu * (204wu + BASE_WEIGHT) = 4 sats let expected_fee = "4 sats".parse::().unwrap(); - let expected_effective_value = value.to_signed() - expected_fee; + let expected_effective_value = (value.to_signed() - expected_fee).unwrap(); assert_eq!(effective_value, expected_effective_value); } diff --git a/bitcoin/src/psbt/mod.rs b/bitcoin/src/psbt/mod.rs index 96e36786a..2471fc41d 100644 --- a/bitcoin/src/psbt/mod.rs +++ b/bitcoin/src/psbt/mod.rs @@ -2249,7 +2249,7 @@ mod tests { }; assert_eq!( t.fee().expect("fee calculation"), - prev_output_val - (output_0_val + output_1_val) + (prev_output_val - (output_0_val + output_1_val)).unwrap() ); // no previous output let mut t2 = t.clone(); diff --git a/units/src/amount/mod.rs b/units/src/amount/mod.rs index b290c625f..e5e23417c 100644 --- a/units/src/amount/mod.rs +++ b/units/src/amount/mod.rs @@ -6,6 +6,7 @@ //! We refer to the documentation on the types for more information. mod error; +mod result; #[cfg(feature = "serde")] pub mod serde; @@ -34,6 +35,7 @@ pub use self::{ OutOfRangeError, ParseAmountError, ParseDenominationError, ParseError, PossiblyConfusingDenominationError, TooPreciseError, UnknownDenominationError, }, + result::{NumOpError, NumOpResult}, signed::SignedAmount, unsigned::Amount, }; diff --git a/units/src/amount/result.rs b/units/src/amount/result.rs new file mode 100644 index 000000000..c0e84f6e7 --- /dev/null +++ b/units/src/amount/result.rs @@ -0,0 +1,588 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Provides a monodic numeric result type that is used to return the result of +//! doing mathematical operations (`core::ops`) on amount types. + +use core::{fmt, ops}; + +use NumOpResult as R; + +use super::{Amount, SignedAmount}; + +/// Result of an operation on [`Amount`] or [`SignedAmount`]. +/// +/// The type parameter `T` should be normally `Amout` or `SignedAmount`. +#[derive(Debug, Clone, PartialEq, Eq)] +#[must_use] +pub enum NumOpResult { + /// Result of a successful mathematical operation. + Valid(T), + /// Result of an unsuccessful mathematical operation. + Error(NumOpError), +} + +impl NumOpResult { + /// Returns the contained valid amount, consuming `self`. + /// + /// # Panics + /// + /// Panics with `msg` if the numeric result is an `Error`. + #[track_caller] + pub fn expect(self, msg: &str) -> T { + match self { + R::Valid(amount) => amount, + R::Error(_) => panic!("{}", msg), + } + } + + /// Returns the contained valid amount, consuming `self`. + /// + /// # Panics + /// + /// Panics if the numeric result is an `Error`. + #[track_caller] + pub fn unwrap(self) -> T { + match self { + R::Valid(amount) => amount, + R::Error(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e), + } + } + + /// Returns the contained error, consuming `self`. + /// + /// # Panics + /// + /// Panics if the numeric result is a valid amount. + #[track_caller] + pub fn unwrap_err(self) -> NumOpError { + match self { + R::Error(e) => e, + R::Valid(a) => panic!("tried to unwrap a valid numeric result: {:?}", a), + } + } + + /// Converts this `NumOpResult` to an `Option`. + pub fn ok(self) -> Option { + match self { + R::Valid(amount) => Some(amount), + R::Error(_) => None, + } + } + + /// Converts this `NumOpResult` to a `Result`. + #[allow(clippy::missing_errors_doc)] + pub fn into_result(self) -> Result { + match self { + R::Valid(amount) => Ok(amount), + R::Error(e) => Err(e), + } + } + + /// Calls `op` if the numeric result is `Valid`, otherwise returns the `Error` value of `self`. + pub fn and_then(self, op: F) -> NumOpResult + where + F: FnOnce(T) -> NumOpResult, + { + match self { + R::Valid(amount) => op(amount), + R::Error(e) => R::Error(e), + } + } + + /// Returns `true` if the numeric result is a valid amount. + pub fn is_valid(&self) -> bool { + match self { + R::Valid(_) => true, + R::Error(_) => false, + } + } + + /// Returns `true` if the numeric result is an invalid amount. + pub fn is_error(&self) -> bool { !self.is_valid() } +} + +impl From for NumOpResult { + fn from(a: Amount) -> Self { Self::Valid(a) } +} +impl From<&Amount> for NumOpResult { + fn from(a: &Amount) -> Self { Self::Valid(*a) } +} + +impl From for NumOpResult { + fn from(a: SignedAmount) -> Self { Self::Valid(a) } +} +impl From<&SignedAmount> for NumOpResult { + fn from(a: &SignedAmount) -> Self { Self::Valid(*a) } +} + +impl ops::Add for Amount { + type Output = NumOpResult; + + fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error() } +} +crate::internal_macros::impl_add_for_amount_references!(Amount); + +impl ops::Add> for Amount { + type Output = NumOpResult; + + fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } +} +impl ops::Add> for &Amount { + type Output = NumOpResult; + + fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } +} +impl ops::Add for NumOpResult { + type Output = NumOpResult; + + fn add(self, rhs: Amount) -> Self::Output { rhs + self } +} +impl ops::Add<&Amount> for NumOpResult { + type Output = NumOpResult; + + fn add(self, rhs: &Amount) -> Self::Output { rhs + self } +} + +impl ops::Sub for Amount { + type Output = NumOpResult; + + fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error() } +} +crate::internal_macros::impl_sub_for_amount_references!(Amount); + +impl ops::Sub> for Amount { + type Output = NumOpResult; + + fn sub(self, rhs: NumOpResult) -> Self::Output { + match rhs { + R::Valid(amount) => self - amount, + R::Error(_) => rhs, + } + } +} +impl ops::Sub> for &Amount { + type Output = NumOpResult; + + fn sub(self, rhs: NumOpResult) -> Self::Output { + match rhs { + R::Valid(amount) => self - amount, + R::Error(_) => rhs, + } + } +} +impl ops::Sub for NumOpResult { + type Output = NumOpResult; + + fn sub(self, rhs: Amount) -> Self::Output { + match self { + R::Valid(amount) => amount - rhs, + R::Error(_) => self, + } + } +} +impl ops::Sub<&Amount> for NumOpResult { + type Output = NumOpResult; + + fn sub(self, rhs: &Amount) -> Self::Output { + match self { + R::Valid(amount) => amount - (*rhs), + R::Error(_) => self, + } + } +} + +impl ops::Mul for Amount { + type Output = NumOpResult; + + fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error() } +} +impl ops::Mul<&u64> for Amount { + type Output = NumOpResult; + + fn mul(self, rhs: &u64) -> Self::Output { self.mul(*rhs) } +} +impl ops::Mul for &Amount { + type Output = NumOpResult; + + fn mul(self, rhs: u64) -> Self::Output { (*self).mul(rhs) } +} +impl ops::Mul<&u64> for &Amount { + type Output = NumOpResult; + + fn mul(self, rhs: &u64) -> Self::Output { self.mul(*rhs) } +} + +impl ops::Div for Amount { + type Output = NumOpResult; + + fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error() } +} +impl ops::Div<&u64> for Amount { + type Output = NumOpResult; + + fn div(self, rhs: &u64) -> Self::Output { self.div(*rhs) } +} +impl ops::Div for &Amount { + type Output = NumOpResult; + + fn div(self, rhs: u64) -> Self::Output { (*self).div(rhs) } +} +impl ops::Div<&u64> for &Amount { + type Output = NumOpResult; + + fn div(self, rhs: &u64) -> Self::Output { (*self).div(*rhs) } +} + +impl ops::Rem for Amount { + type Output = NumOpResult; + + fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error() } +} +impl ops::Rem<&u64> for Amount { + type Output = NumOpResult; + + fn rem(self, modulus: &u64) -> Self::Output { self.rem(*modulus) } +} +impl ops::Rem for &Amount { + type Output = NumOpResult; + + fn rem(self, modulus: u64) -> Self::Output { (*self).rem(modulus) } +} +impl ops::Rem<&u64> for &Amount { + type Output = NumOpResult; + + fn rem(self, modulus: &u64) -> Self::Output { (*self).rem(*modulus) } +} + +impl ops::Add for SignedAmount { + type Output = NumOpResult; + + fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error() } +} +crate::internal_macros::impl_add_for_amount_references!(SignedAmount); + +impl ops::Add> for SignedAmount { + type Output = NumOpResult; + + fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } +} +impl ops::Add> for &SignedAmount { + type Output = NumOpResult; + + fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } +} +impl ops::Add for NumOpResult { + type Output = NumOpResult; + + fn add(self, rhs: SignedAmount) -> Self::Output { rhs + self } +} +impl ops::Add<&SignedAmount> for NumOpResult { + type Output = NumOpResult; + + fn add(self, rhs: &SignedAmount) -> Self::Output { rhs + self } +} + +impl ops::Sub for SignedAmount { + type Output = NumOpResult; + + fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error() } +} +crate::internal_macros::impl_sub_for_amount_references!(SignedAmount); + +impl ops::Sub> for SignedAmount { + type Output = NumOpResult; + + fn sub(self, rhs: NumOpResult) -> Self::Output { + match rhs { + R::Valid(amount) => amount - rhs, + R::Error(_) => rhs, + } + } +} +impl ops::Sub> for &SignedAmount { + type Output = NumOpResult; + + fn sub(self, rhs: NumOpResult) -> Self::Output { + match rhs { + R::Valid(amount) => amount - rhs, + R::Error(_) => rhs, + } + } +} +impl ops::Sub for NumOpResult { + type Output = NumOpResult; + + fn sub(self, rhs: SignedAmount) -> Self::Output { + match self { + R::Valid(amount) => amount - rhs, + R::Error(_) => self, + } + } +} +impl ops::Sub<&SignedAmount> for NumOpResult { + type Output = NumOpResult; + + fn sub(self, rhs: &SignedAmount) -> Self::Output { + match self { + R::Valid(amount) => amount - *rhs, + R::Error(_) => self, + } + } +} + +impl ops::Mul for SignedAmount { + type Output = NumOpResult; + + fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error() } +} +impl ops::Mul<&i64> for SignedAmount { + type Output = NumOpResult; + + fn mul(self, rhs: &i64) -> Self::Output { self.mul(*rhs) } +} +impl ops::Mul for &SignedAmount { + type Output = NumOpResult; + + fn mul(self, rhs: i64) -> Self::Output { (*self).mul(rhs) } +} +impl ops::Mul<&i64> for &SignedAmount { + type Output = NumOpResult; + + fn mul(self, rhs: &i64) -> Self::Output { self.mul(*rhs) } +} + +impl ops::Div for SignedAmount { + type Output = NumOpResult; + + fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error() } +} +impl ops::Div<&i64> for SignedAmount { + type Output = NumOpResult; + + fn div(self, rhs: &i64) -> Self::Output { self.div(*rhs) } +} +impl ops::Div for &SignedAmount { + type Output = NumOpResult; + + fn div(self, rhs: i64) -> Self::Output { (*self).div(rhs) } +} +impl ops::Div<&i64> for &SignedAmount { + type Output = NumOpResult; + + fn div(self, rhs: &i64) -> Self::Output { (*self).div(*rhs) } +} + +impl ops::Rem for SignedAmount { + type Output = NumOpResult; + + fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error() } +} +impl ops::Rem<&i64> for SignedAmount { + type Output = NumOpResult; + + fn rem(self, modulus: &i64) -> Self::Output { self.rem(*modulus) } +} +impl ops::Rem for &SignedAmount { + type Output = NumOpResult; + + fn rem(self, modulus: i64) -> Self::Output { (*self).rem(modulus) } +} +impl ops::Rem<&i64> for &SignedAmount { + type Output = NumOpResult; + + fn rem(self, modulus: &i64) -> Self::Output { (*self).rem(*modulus) } +} + +impl ops::Neg for SignedAmount { + type Output = Self; + + fn neg(self) -> Self::Output { Self::from_sat(self.to_sat().neg()) } +} + +impl ops::Add for NumOpResult +where + T: ops::Add>, +{ + type Output = NumOpResult; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Add> for &NumOpResult +where + T: ops::Add> + Copy, +{ + type Output = NumOpResult; + + fn add(self, rhs: NumOpResult) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => *lhs + rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Add<&NumOpResult> for NumOpResult +where + T: ops::Add> + Copy, +{ + type Output = NumOpResult; + + fn add(self, rhs: &NumOpResult) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + *rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Add for &NumOpResult +where + T: ops::Add> + Copy, +{ + type Output = NumOpResult; + + fn add(self, rhs: &NumOpResult) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => *lhs + *rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} + +impl ops::Sub for NumOpResult +where + T: ops::Sub>, +{ + type Output = NumOpResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => lhs - rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Sub> for &NumOpResult +where + T: ops::Sub> + Copy, +{ + type Output = NumOpResult; + + fn sub(self, rhs: NumOpResult) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => *lhs - rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Sub<&NumOpResult> for NumOpResult +where + T: ops::Sub> + Copy, +{ + type Output = NumOpResult; + + fn sub(self, rhs: &NumOpResult) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => lhs - *rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} +impl ops::Sub for &NumOpResult +where + T: ops::Sub> + Copy, +{ + type Output = NumOpResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (R::Valid(lhs), R::Valid(rhs)) => *lhs - *rhs, + (_, _) => R::Error(NumOpError {}), + } + } +} + +impl core::iter::Sum> for NumOpResult { + fn sum(iter: I) -> Self + where + I: Iterator>, + { + iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, + (_, _) => R::Error(NumOpError {}), + }) + } +} +impl<'a> core::iter::Sum<&'a NumOpResult> for NumOpResult { + fn sum(iter: I) -> Self + where + I: Iterator>, + { + iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, + (_, _) => R::Error(NumOpError {}), + }) + } +} + +impl core::iter::Sum> for NumOpResult { + fn sum(iter: I) -> Self + where + I: Iterator>, + { + iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, + (_, _) => R::Error(NumOpError {}), + }) + } +} +impl<'a> core::iter::Sum<&'a NumOpResult> for NumOpResult { + fn sum(iter: I) -> Self + where + I: Iterator>, + { + iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) { + (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, + (_, _) => R::Error(NumOpError {}), + }) + } +} + +pub(in crate::amount) trait OptionExt { + fn valid_or_error(self) -> NumOpResult; +} + +impl OptionExt for Option { + fn valid_or_error(self) -> NumOpResult { + match self { + Some(amount) => R::Valid(amount), + None => R::Error(NumOpError {}), + } + } +} + +impl OptionExt for Option { + fn valid_or_error(self) -> NumOpResult { + match self { + Some(amount) => R::Valid(amount), + None => R::Error(NumOpError {}), + } + } +} + +/// An error occurred while doing a mathematical operation. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct NumOpError; + +impl fmt::Display for NumOpError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a math operation on amounts gave an invalid numeric result") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for NumOpError {} diff --git a/units/src/amount/signed.rs b/units/src/amount/signed.rs index 99602eca9..f51be5570 100644 --- a/units/src/amount/signed.rs +++ b/units/src/amount/signed.rs @@ -5,7 +5,7 @@ #[cfg(feature = "alloc")] use alloc::string::{String, ToString}; use core::str::FromStr; -use core::{default, fmt, ops}; +use core::{default, fmt}; #[cfg(feature = "arbitrary")] use arbitrary::{Arbitrary, Unstructured}; @@ -479,68 +479,6 @@ impl fmt::Display for SignedAmount { } } -impl ops::Add for SignedAmount { - type Output = SignedAmount; - - fn add(self, rhs: SignedAmount) -> Self::Output { - self.checked_add(rhs).expect("SignedAmount addition error") - } -} -crate::internal_macros::impl_add_for_references!(SignedAmount); -crate::internal_macros::impl_add_assign!(SignedAmount); - -impl ops::Sub for SignedAmount { - type Output = SignedAmount; - - fn sub(self, rhs: SignedAmount) -> Self::Output { - self.checked_sub(rhs).expect("SignedAmount subtraction error") - } -} -crate::internal_macros::impl_sub_for_references!(SignedAmount); -crate::internal_macros::impl_sub_assign!(SignedAmount); - -impl ops::Rem for SignedAmount { - type Output = SignedAmount; - - fn rem(self, modulus: i64) -> Self::Output { - self.checked_rem(modulus).expect("SignedAmount remainder error") - } -} - -impl ops::RemAssign for SignedAmount { - fn rem_assign(&mut self, modulus: i64) { *self = *self % modulus } -} - -impl ops::Mul for SignedAmount { - type Output = SignedAmount; - - fn mul(self, rhs: i64) -> Self::Output { - self.checked_mul(rhs).expect("SignedAmount multiplication error") - } -} - -impl ops::MulAssign for SignedAmount { - fn mul_assign(&mut self, rhs: i64) { *self = *self * rhs } -} - -impl ops::Div for SignedAmount { - type Output = SignedAmount; - - fn div(self, rhs: i64) -> Self::Output { - self.checked_div(rhs).expect("SignedAmount division error") - } -} - -impl ops::DivAssign for SignedAmount { - fn div_assign(&mut self, rhs: i64) { *self = *self / rhs } -} - -impl ops::Neg for SignedAmount { - type Output = Self; - - fn neg(self) -> Self::Output { Self(self.0.neg()) } -} - impl FromStr for SignedAmount { type Err = ParseError; @@ -572,23 +510,6 @@ impl From for SignedAmount { } } -impl core::iter::Sum for SignedAmount { - fn sum>(iter: I) -> Self { - let sats: i64 = iter.map(|amt| amt.0).sum(); - Self::from_sat(sats) - } -} - -impl<'a> core::iter::Sum<&'a SignedAmount> for SignedAmount { - fn sum(iter: I) -> Self - where - I: Iterator, - { - let sats: i64 = iter.map(|amt| amt.0).sum(); - Self::from_sat(sats) - } -} - #[cfg(feature = "arbitrary")] impl<'a> Arbitrary<'a> for SignedAmount { fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { diff --git a/units/src/amount/tests.rs b/units/src/amount/tests.rs index a733e3042..f5b450fdb 100644 --- a/units/src/amount/tests.rs +++ b/units/src/amount/tests.rs @@ -116,28 +116,15 @@ fn mul_div() { let sat = Amount::from_sat; let ssat = SignedAmount::from_sat; - assert_eq!(sat(14) * 3, sat(42)); - assert_eq!(sat(14) / 2, sat(7)); - assert_eq!(sat(14) % 3, sat(2)); - assert_eq!(ssat(-14) * 3, ssat(-42)); - assert_eq!(ssat(-14) / 2, ssat(-7)); - assert_eq!(ssat(-14) % 3, ssat(-2)); + let op_result_sat = |sat| NumOpResult::Valid(Amount::from_sat(sat)); + let op_result_ssat = |sat| NumOpResult::Valid(SignedAmount::from_sat(sat)); - let mut a = sat(30); - a /= 3; - assert_eq!(a, sat(10)); - a %= 3; - assert_eq!(a, sat(1)); - a *= 3; - assert_eq!(a, sat(3)); - - let mut b = ssat(30); - b /= 3; - assert_eq!(b, ssat(10)); - b %= 3; - assert_eq!(b, ssat(1)); - b *= 3; - assert_eq!(b, ssat(3)); + assert_eq!(sat(14) * 3, op_result_sat(42)); + assert_eq!(sat(14) / 2, op_result_sat(7)); + assert_eq!(sat(14) % 3, op_result_sat(2)); + assert_eq!(ssat(-14) * 3, op_result_ssat(-42)); + assert_eq!(ssat(-14) / 2, op_result_ssat(-7)); + assert_eq!(ssat(-14) % 3, op_result_ssat(-2)); } #[test] @@ -149,11 +136,41 @@ fn neg() { #[cfg(feature = "std")] #[test] fn overflows() { - // panic on overflow - let result = panic::catch_unwind(|| Amount::MAX + Amount::from_sat_unchecked(1)); - assert!(result.is_err()); - let result = panic::catch_unwind(|| Amount::from_sat_unchecked(8_446_744_073_709_551_615) * 3); - assert!(result.is_err()); + let result = Amount::MAX + Amount::from_sat_unchecked(1); + assert!(result.is_error()); + let result = Amount::from_sat_unchecked(8_446_744_073_709_551_615) * 3; + assert!(result.is_error()); +} + +#[test] +fn add() { + let sat = Amount::from_sat; + let ssat = SignedAmount::from_sat; + + assert!(sat(0) + sat(0) == sat(0).into()); + assert!(sat(127) + sat(179) == sat(306).into()); + + assert!(ssat(0) + ssat(0) == ssat(0).into()); + assert!(ssat(127) + ssat(179) == ssat(306).into()); + assert!(ssat(-127) + ssat(179) == ssat(52).into()); + assert!(ssat(127) + ssat(-179) == ssat(-52).into()); + assert!(ssat(-127) + ssat(-179) == ssat(-306).into()); +} + +#[test] +fn sub() { + let sat = Amount::from_sat; + let ssat = SignedAmount::from_sat; + + assert!(sat(0) - sat(0) == sat(0).into()); + assert!(sat(179) - sat(127) == sat(52).into()); + assert!((sat(127) - sat(179)).is_error()); + + assert!(ssat(0) - ssat(0) == ssat(0).into()); + assert!(ssat(127) - ssat(179) == ssat(-52).into()); + assert!(ssat(-127) - ssat(179) == ssat(-306).into()); + assert!(ssat(127) - ssat(-179) == ssat(306).into()); + assert!(ssat(-127) - ssat(-179) == ssat(52).into()); } #[test] @@ -990,18 +1007,28 @@ fn sum_amounts() { let sat = Amount::from_sat; let ssat = SignedAmount::from_sat; - assert_eq!([].iter().sum::(), Amount::ZERO); - assert_eq!([].iter().sum::(), SignedAmount::ZERO); + assert_eq!([].iter().sum::>(), Amount::ZERO.into()); + assert_eq!([].iter().sum::>(), SignedAmount::ZERO.into()); let amounts = [sat(42), sat(1337), sat(21)]; - assert_eq!(amounts.iter().sum::(), sat(1400)); - let sum = amounts.into_iter().sum::(); - assert_eq!(sum, sat(1400)); + assert_eq!( + amounts.iter().map(|a| NumOpResult::Valid(*a)).sum::>(), + sat(1400).into(), + ); + assert_eq!( + amounts.into_iter().map(NumOpResult::Valid).sum::>(), + sat(1400).into(), + ); let amounts = [ssat(-42), ssat(1337), ssat(21)]; - assert_eq!(amounts.iter().sum::(), ssat(1316)); - let sum = amounts.into_iter().sum::(); - assert_eq!(sum, ssat(1316)); + assert_eq!( + amounts.iter().map(NumOpResult::from).sum::>(), + ssat(1316).into(), + ); + assert_eq!( + amounts.into_iter().map(NumOpResult::from).sum::>(), + ssat(1316).into() + ); } #[test] @@ -1108,10 +1135,11 @@ fn unsigned_addition() { let two = sat(2); let three = sat(3); - assert!(one + two == three); - assert!(&one + two == three); - assert!(one + &two == three); - assert!(&one + &two == three); + assert!((one + two) == three.into()); + assert!((one + two) == three.into()); + assert!((&one + two) == three.into()); + assert!((one + &two) == three.into()); + assert!((&one + &two) == three.into()); } #[test] @@ -1123,36 +1151,10 @@ fn unsigned_subtract() { let two = sat(2); let three = sat(3); - assert!(three - two == one); - assert!(&three - two == one); - assert!(three - &two == one); - assert!(&three - &two == one); -} - -#[test] -fn unsigned_add_assign() { - let sat = Amount::from_sat; - - let mut f = sat(1); - f += sat(2); - assert_eq!(f, sat(3)); - - let mut f = sat(1); - f += &sat(2); - assert_eq!(f, sat(3)); -} - -#[test] -fn unsigned_sub_assign() { - let sat = Amount::from_sat; - - let mut f = sat(3); - f -= sat(2); - assert_eq!(f, sat(1)); - - let mut f = sat(3); - f -= &sat(2); - assert_eq!(f, sat(1)); + assert!(three - two == one.into()); + assert!(&three - two == one.into()); + assert!(three - &two == one.into()); + assert!(&three - &two == one.into()); } #[test] @@ -1164,10 +1166,10 @@ fn signed_addition() { let two = ssat(2); let three = ssat(3); - assert!(one + two == three); - assert!(&one + two == three); - assert!(one + &two == three); - assert!(&one + &two == three); + assert!(one + two == three.into()); + assert!(&one + two == three.into()); + assert!(one + &two == three.into()); + assert!(&one + &two == three.into()); } #[test] @@ -1179,36 +1181,10 @@ fn signed_subtract() { let two = ssat(2); let three = ssat(3); - assert!(three - two == one); - assert!(&three - two == one); - assert!(three - &two == one); - assert!(&three - &two == one); -} - -#[test] -fn signed_add_assign() { - let ssat = SignedAmount::from_sat; - - let mut f = ssat(1); - f += ssat(2); - assert_eq!(f, ssat(3)); - - let mut f = ssat(1); - f += &ssat(2); - assert_eq!(f, ssat(3)); -} - -#[test] -fn signed_sub_assign() { - let ssat = SignedAmount::from_sat; - - let mut f = ssat(3); - f -= ssat(2); - assert_eq!(f, ssat(1)); - - let mut f = ssat(3); - f -= &ssat(2); - assert_eq!(f, ssat(1)); + assert!(three - two == one.into()); + assert!(&three - two == one.into()); + assert!(three - &two == one.into()); + assert!(&three - &two == one.into()); } #[test] @@ -1219,3 +1195,185 @@ fn check_const() { assert_eq!(Amount::FIFTY_BTC.to_sat(), Amount::ONE_BTC.to_sat() * 50); assert_eq!(Amount::MAX_MONEY.to_sat() as i64, SignedAmount::MAX_MONEY.to_sat()); } + +// Verify we have implemented all combinations of ops for `Amount` and `SignedAmount`. +// It's easier to read this test that check the code. +#[test] +#[allow(clippy::op_ref)] // We are explicitly testing the references work with ops. +fn amount_tyes_all_ops() { + // Sanity check than stdlib supports the set of reference combinations for the ops we want. + { + let x = 127; + + let _ = x + x; + let _ = &x + x; + let _ = x + &x; + let _ = &x + &x; + + let _ = x - x; + let _ = &x - x; + let _ = x - &x; + let _ = &x - &x; + + let _ = -x; + } + + let sat = Amount::from_sat(1); + let ssat = SignedAmount::from_sat(1); + + // Add + let _ = sat + sat; + let _ = &sat + sat; + let _ = sat + &sat; + let _ = &sat + &sat; + + // let _ = ssat + sat; + // let _ = &ssat + sat; + // let _ = ssat + &sat; + // let _ = &ssat + &sat; + + // let _ = sat + ssat; + // let _ = &sat + ssat; + // let _ = sat + &ssat; + // let _ = &sat + &ssat; + + let _ = ssat + ssat; + let _ = &ssat + ssat; + let _ = ssat + &ssat; + let _ = &ssat + &ssat; + + // Sub + let _ = sat - sat; + let _ = &sat - sat; + let _ = sat - &sat; + let _ = &sat - &sat; + + // let _ = ssat - sat; + // let _ = &ssat - sat; + // let _ = ssat - &sat; + // let _ = &ssat - &sat; + + // let _ = sat - ssat; + // let _ = &sat - ssat; + // let _ = sat - &ssat; + // let _ = &sat - &ssat; + + let _ = ssat - ssat; + let _ = &ssat - ssat; + let _ = ssat - &ssat; + let _ = &ssat - &ssat; + + // let _ = sat * sat; // Intentionally not supported. + + // Mul + let _ = sat * 3; + let _ = sat * &3; + let _ = &sat * 3; + let _ = &sat * &3; + + let _ = ssat * 3_i64; // Explicit type for the benefit of the reader. + let _ = ssat * &3; + let _ = &ssat * 3; + let _ = &ssat * &3; + + // Div + let _ = sat / 3; + let _ = &sat / 3; + let _ = sat / &3; + let _ = &sat / &3; + + let _ = ssat / 3_i64; // Explicit type for the benefit of the reader. + let _ = &ssat / 3; + let _ = ssat / &3; + let _ = &ssat / &3; + + // Rem + let _ = sat % 3; + let _ = &sat % 3; + let _ = sat % &3; + let _ = &sat % &3; + + let _ = ssat % 3; + let _ = &ssat % 3; + let _ = ssat % &3; + let _ = &ssat % &3; + + // FIXME: Do we want to support this? + // let _ = sat / sat; + // + // "How many times does this amount go into that amount?" - seems + // like a reasonable question to ask. + + // FIXME: Do we want to support these? + // let _ = -sat; + // let _ = -ssat; +} + +// FIXME: Should we support this sort of thing? +// It will be a lot more code for possibly not that much benefit. +#[test] +fn can_ops_on_amount_and_signed_amount() { + // let res: NumOpResult = sat + ssat; +} + +// Verify we have implemented all combinations of ops for the `NumOpResult` type. +// It's easier to read this test that check the code. +#[test] +#[allow(clippy::op_ref)] // We are explicitly testing the references work with ops. +fn amount_op_result_all_ops() { + let sat = Amount::from_sat(1); + // let ssat = SignedAmount::from_sat(1); + + // Explicit type as sanity check. + let res: NumOpResult = sat + sat; + // let sres: NumOpResult = ssat + ssat; + + // Operations that where RHS is the result of another operation. + let _ = sat + res.clone(); + let _ = &sat + res.clone(); + // let _ = sat + &res.clone(); + // let _ = &sat + &res.clone(); + + let _ = sat - res.clone(); + let _ = &sat - res.clone(); + // let _ = sat - &res.clone(); + // let _ = &sat - &res.clone(); + + // Operations that where LHS is the result of another operation. + let _ = res.clone() + sat; + // let _ = &res.clone() + sat; + let _ = res.clone() + &sat; + // let _ = &res.clone() + &sat; + + let _ = res.clone() - sat; + // let _ = &res.clone() - sat; + let _ = res.clone() - &sat; + // let _ = &res.clone() - &sat; + + // Operations that where both sides are the result of another operation. + let _ = res.clone() + res.clone(); + // let _ = &res.clone() + res.clone(); + // let _ = res.clone() + &res.clone(); + // let _ = &res.clone() + &res.clone(); + + let _ = res.clone() - res.clone(); + // let _ = &res.clone() - res.clone(); + // let _ = res.clone() - &res.clone(); + // let _ = &res.clone() - &res.clone(); +} + +// Verify we have implemented all `Sum` for the `NumOpResult` type. +#[test] +fn amount_op_result_sum() { + let res = Amount::from_sat(1) + Amount::from_sat(1); + let amounts = [res.clone(), res.clone()]; + let amount_refs = [&res, &res]; + + // Sum iterators. + let _ = amounts.iter().sum::>(); + let _ = amount_refs.iter().copied().sum::>(); + let _ = amount_refs.into_iter().sum::>(); + + // FIXME: Should we support this? I don't think so (Tobin). + // let _ = amount_refs.iter().sum::>(); +} diff --git a/units/src/amount/unsigned.rs b/units/src/amount/unsigned.rs index 645dbe9d8..c52a8f2c5 100644 --- a/units/src/amount/unsigned.rs +++ b/units/src/amount/unsigned.rs @@ -5,7 +5,7 @@ #[cfg(feature = "alloc")] use alloc::string::{String, ToString}; use core::str::FromStr; -use core::{default, fmt, ops}; +use core::{default, fmt}; #[cfg(feature = "arbitrary")] use arbitrary::{Arbitrary, Unstructured}; @@ -422,60 +422,6 @@ impl fmt::Display for Amount { } } -impl ops::Add for Amount { - type Output = Amount; - - fn add(self, rhs: Amount) -> Self::Output { - self.checked_add(rhs).expect("Amount addition error") - } -} -crate::internal_macros::impl_add_for_references!(Amount); -crate::internal_macros::impl_add_assign!(Amount); - -impl ops::Sub for Amount { - type Output = Amount; - - fn sub(self, rhs: Amount) -> Self::Output { - self.checked_sub(rhs).expect("Amount subtraction error") - } -} -crate::internal_macros::impl_sub_for_references!(Amount); -crate::internal_macros::impl_sub_assign!(Amount); - -impl ops::Rem for Amount { - type Output = Amount; - - fn rem(self, modulus: u64) -> Self::Output { - self.checked_rem(modulus).expect("Amount remainder error") - } -} - -impl ops::RemAssign for Amount { - fn rem_assign(&mut self, modulus: u64) { *self = *self % modulus } -} - -impl ops::Mul for Amount { - type Output = Amount; - - fn mul(self, rhs: u64) -> Self::Output { - self.checked_mul(rhs).expect("Amount multiplication error") - } -} - -impl ops::MulAssign for Amount { - fn mul_assign(&mut self, rhs: u64) { *self = *self * rhs } -} - -impl ops::Div for Amount { - type Output = Amount; - - fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).expect("Amount division error") } -} - -impl ops::DivAssign for Amount { - fn div_assign(&mut self, rhs: u64) { *self = *self / rhs } -} - impl FromStr for Amount { type Err = ParseError; @@ -506,23 +452,6 @@ impl TryFrom for Amount { fn try_from(value: SignedAmount) -> Result { value.to_unsigned() } } -impl core::iter::Sum for Amount { - fn sum>(iter: I) -> Self { - let sats: u64 = iter.map(|amt| amt.0).sum(); - Self::from_sat(sats) - } -} - -impl<'a> core::iter::Sum<&'a Amount> for Amount { - fn sum(iter: I) -> Self - where - I: Iterator, - { - let sats: u64 = iter.map(|amt| amt.0).sum(); - Self::from_sat(sats) - } -} - #[cfg(feature = "arbitrary")] impl<'a> Arbitrary<'a> for Amount { fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { diff --git a/units/src/amount/verification.rs b/units/src/amount/verification.rs index f0d7b47eb..ecf35d217 100644 --- a/units/src/amount/verification.rs +++ b/units/src/amount/verification.rs @@ -29,19 +29,11 @@ fn u_amount_homomorphic() { let a2 = Amount::from_sat(n2); kani::assume(a1.checked_add(a2).is_some()); // Adding amounts doesn't overflow. - assert_eq!(Amount::from_sat(n1) + Amount::from_sat(n2), Amount::from_sat(n1 + n2)); - - let mut amt = Amount::from_sat(n1); - amt += Amount::from_sat(n2); - assert_eq!(amt, Amount::from_sat(n1 + n2)); + assert_eq!(Amount::from_sat(n1) + Amount::from_sat(n2), Amount::from_sat(n1 + n2).into()); let max = cmp::max(n1, n2); let min = cmp::min(n1, n2); - assert_eq!(Amount::from_sat(max) - Amount::from_sat(min), Amount::from_sat(max - min)); - - let mut amt = Amount::from_sat(max); - amt -= Amount::from_sat(min); - assert_eq!(amt, Amount::from_sat(max - min)); + assert_eq!(Amount::from_sat(max) - Amount::from_sat(min), Amount::from_sat(max - min).into()); } #[kani::unwind(4)] @@ -59,17 +51,10 @@ fn s_amount_homomorphic() { assert_eq!( SignedAmount::from_sat(n1) + SignedAmount::from_sat(n2), - SignedAmount::from_sat(n1 + n2) + SignedAmount::from_sat(n1 + n2).into() ); assert_eq!( SignedAmount::from_sat(n1) - SignedAmount::from_sat(n2), - SignedAmount::from_sat(n1 - n2) + SignedAmount::from_sat(n1 - n2).into() ); - - let mut amt = SignedAmount::from_sat(n1); - amt += SignedAmount::from_sat(n2); - assert_eq!(amt, SignedAmount::from_sat(n1 + n2)); - let mut amt = SignedAmount::from_sat(n1); - amt -= SignedAmount::from_sat(n2); - assert_eq!(amt, SignedAmount::from_sat(n1 - n2)); } diff --git a/units/src/internal_macros.rs b/units/src/internal_macros.rs index 475f9374e..1bae68986 100644 --- a/units/src/internal_macros.rs +++ b/units/src/internal_macros.rs @@ -34,6 +34,36 @@ macro_rules! impl_add_for_references { } pub(crate) use impl_add_for_references; +/// Implements `ops::Add` for various amount references. +/// +/// Requires `$ty` it implement `Add` e.g. 'impl Add for T'. Adds impls of: +/// +/// - Add for &T +/// - Add<&T> for T +/// - Add<&T> for &T +macro_rules! impl_add_for_amount_references { + ($ty:ident) => { + impl core::ops::Add<$ty> for &$ty { + type Output = NumOpResult<$ty>; + + fn add(self, rhs: $ty) -> Self::Output { *self + rhs } + } + + impl core::ops::Add<&$ty> for $ty { + type Output = NumOpResult<$ty>; + + fn add(self, rhs: &$ty) -> Self::Output { self + *rhs } + } + + impl<'a> core::ops::Add<&'a $ty> for &$ty { + type Output = NumOpResult<$ty>; + + fn add(self, rhs: &'a $ty) -> Self::Output { *self + *rhs } + } + }; +} +pub(crate) use impl_add_for_amount_references; + /// Implement `ops::AddAssign` for `$ty` and `&$ty`. macro_rules! impl_add_assign { ($ty:ident) => { @@ -78,6 +108,36 @@ macro_rules! impl_sub_for_references { } pub(crate) use impl_sub_for_references; +/// Implement `ops::Sub` for various amount references. +/// +/// Requires `$ty` it implement `Sub` e.g. 'impl Sub for T'. Adds impls of: +/// +/// - Sub for &T +/// - Sub<&T> for T +/// - Sub<&T> for &T +macro_rules! impl_sub_for_amount_references { + ($ty:ident) => { + impl core::ops::Sub<$ty> for &$ty { + type Output = NumOpResult<$ty>; + + fn sub(self, rhs: $ty) -> Self::Output { *self - rhs } + } + + impl core::ops::Sub<&$ty> for $ty { + type Output = NumOpResult<$ty>; + + fn sub(self, rhs: &$ty) -> Self::Output { self - *rhs } + } + + impl<'a> core::ops::Sub<&'a $ty> for &$ty { + type Output = NumOpResult<$ty>; + + fn sub(self, rhs: &'a $ty) -> Self::Output { *self - *rhs } + } + }; +} +pub(crate) use impl_sub_for_amount_references; + /// Implement `ops::SubAssign` for `$ty` and `&$ty`. macro_rules! impl_sub_assign { ($ty:ident) => {