diff --git a/src/util/amount.rs b/src/util/amount.rs index c9b7f3eb..13756844 100644 --- a/src/util/amount.rs +++ b/src/util/amount.rs @@ -863,6 +863,35 @@ impl ::std::iter::Sum for SignedAmount { } } +/// Calculate the sum over the iterator using checked arithmetic. +pub trait CheckedSum { + /// Calculate the sum over the iterator using checked arithmetic. If an over or underflow would + /// happen it returns `None`. + fn checked_sum(self) -> Option; +} + +impl CheckedSum for T where T: Iterator { + fn checked_sum(mut self) -> Option { + let first = Some(self.next().unwrap_or_default()); + + self.fold( + first, + |acc, item| acc.and_then(|acc| acc.checked_add(item)) + ) + } +} + +impl CheckedSum for T where T: Iterator { + fn checked_sum(mut self) -> Option { + let first = Some(self.next().unwrap_or_default()); + + self.fold( + first, + |acc, item| acc.and_then(|acc| acc.checked_add(item)) + ) + } +} + #[cfg(feature = "serde")] pub mod serde { // methods are implementation of a standardized serde-specific signature @@ -1548,4 +1577,50 @@ mod tests { let sum = amounts.into_iter().sum::(); assert_eq!(SignedAmount::from_sat(1316), sum); } + + #[test] + fn checked_sum_amounts() { + assert_eq!(Some(Amount::from_sat(0)), vec![].into_iter().checked_sum()); + assert_eq!(Some(SignedAmount::from_sat(0)), vec![].into_iter().checked_sum()); + + let amounts = vec![ + Amount::from_sat(42), + Amount::from_sat(1337), + Amount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(Some(Amount::from_sat(1400)), sum); + + let amounts = vec![ + Amount::from_sat(u64::max_value()), + Amount::from_sat(1337), + Amount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(i64::min_value()), + SignedAmount::from_sat(-1), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(i64::max_value()), + SignedAmount::from_sat(1), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(None, sum); + + let amounts = vec![ + SignedAmount::from_sat(42), + SignedAmount::from_sat(3301), + SignedAmount::from_sat(21) + ]; + let sum = amounts.into_iter().checked_sum(); + assert_eq!(Some(SignedAmount::from_sat(3364)), sum); + } }