diff --git a/src/util/amount.rs b/src/util/amount.rs index 777501f0..427b67cc 100644 --- a/src/util/amount.rs +++ b/src/util/amount.rs @@ -523,6 +523,13 @@ impl FromStr for Amount { } } +impl ::std::iter::Sum for Amount { + fn sum>(iter: I) -> Self { + let sats: u64 = iter.map(|amt| amt.0).sum(); + Amount::from_sat(sats) + } +} + /// SignedAmount /// /// The [SignedAmount] type can be used to express Bitcoin amounts that supports @@ -848,6 +855,52 @@ impl FromStr for SignedAmount { } } +impl ::std::iter::Sum for SignedAmount { + fn sum>(iter: I) -> Self { + let sats: i64 = iter.map(|amt| amt.0).sum(); + SignedAmount::from_sat(sats) + } +} + +/// Calculate the sum over the iterator using checked arithmetic. +pub trait CheckedSum: private::SumSeal { + /// Calculate the sum over the iterator using checked arithmetic. If an over or underflow would + /// happen it returns `None`. + fn checked_sum(self) -> Option; +} + +impl CheckedSum for T where T: Iterator { + fn checked_sum(mut self) -> Option { + let first = Some(self.next().unwrap_or_default()); + + self.fold( + first, + |acc, item| acc.and_then(|acc| acc.checked_add(item)) + ) + } +} + +impl CheckedSum for T where T: Iterator { + fn checked_sum(mut self) -> Option { + let first = Some(self.next().unwrap_or_default()); + + self.fold( + first, + |acc, item| acc.and_then(|acc| acc.checked_add(item)) + ) + } +} + +mod private { + use ::{Amount, SignedAmount}; + + /// Used to seal the `CheckedSum` trait + pub trait SumSeal {} + + impl SumSeal for T where T: Iterator {} + impl SumSeal for T where T: Iterator {} +} + #[cfg(feature = "serde")] pub mod serde { // methods are implementation of a standardized serde-specific signature @@ -1516,4 +1569,72 @@ mod tests { let value_without: serde_json::Value = serde_json::from_str("{}").unwrap(); assert_eq!(without, serde_json::from_value(value_without).unwrap()); } + + #[test] + fn sum_amounts() { + assert_eq!(Amount::from_sat(0), vec![].into_iter().sum::()); + assert_eq!(SignedAmount::from_sat(0), vec![].into_iter().sum::()); + + let amounts = vec![ + Amount::from_sat(42), + Amount::from_sat(1337), + Amount::from_sat(21) + ]; + let sum = amounts.into_iter().sum::(); + assert_eq!(Amount::from_sat(1400), sum); + + let amounts = vec![ + SignedAmount::from_sat(-42), + SignedAmount::from_sat(1337), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().sum::(); + assert_eq!(SignedAmount::from_sat(1316), sum); + } + + #[test] + fn checked_sum_amounts() { + assert_eq!(Some(Amount::from_sat(0)), vec![].into_iter().checked_sum()); + assert_eq!(Some(SignedAmount::from_sat(0)), vec![].into_iter().checked_sum()); + + let amounts = vec![ + Amount::from_sat(42), + Amount::from_sat(1337), + Amount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(Some(Amount::from_sat(1400)), sum); + + let amounts = vec![ + Amount::from_sat(u64::max_value()), + Amount::from_sat(1337), + Amount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(i64::min_value()), + SignedAmount::from_sat(-1), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(i64::max_value()), + SignedAmount::from_sat(1), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(42), + SignedAmount::from_sat(3301), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(Some(SignedAmount::from_sat(3364)), sum); + } }