// SPDX-License-Identifier: CC0-1.0 //! Provides a monodic type returned by mathematical operations (`core::ops`). use core::ops; use NumOpResult as R; use super::{Amount, SignedAmount}; use crate::{NumOpError, NumOpResult, OptionExt}; impl From for NumOpResult { fn from(a: Amount) -> Self { Self::Valid(a) } } impl From<&Amount> for NumOpResult { fn from(a: &Amount) -> Self { Self::Valid(*a) } } impl From for NumOpResult { fn from(a: SignedAmount) -> Self { Self::Valid(a) } } impl From<&SignedAmount> for NumOpResult { fn from(a: &SignedAmount) -> Self { Self::Valid(*a) } } crate::internal_macros::impl_op_for_references! { impl ops::Add for Amount { type Output = NumOpResult; fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error() } } impl ops::Add> for Amount { type Output = NumOpResult; fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } } impl ops::Sub for Amount { type Output = NumOpResult; fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error() } } impl ops::Sub> for Amount { type Output = NumOpResult; fn sub(self, rhs: NumOpResult) -> Self::Output { match rhs { R::Valid(amount) => self - amount, R::Error(_) => rhs, } } } impl ops::Mul for Amount { type Output = NumOpResult; fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error() } } impl ops::Mul for NumOpResult { type Output = NumOpResult; fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) } } impl ops::Mul for u64 { type Output = NumOpResult; fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error() } } impl ops::Mul> for u64 { type Output = NumOpResult; fn mul(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|rhs| self * rhs) } } impl ops::Div for Amount { type Output = NumOpResult; fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error() } } impl ops::Div for NumOpResult { type Output = NumOpResult; fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) } } impl ops::Div for Amount { type Output = u64; fn div(self, rhs: Amount) -> Self::Output { self.to_sat() / rhs.to_sat() } } impl ops::Rem for Amount { type Output = NumOpResult; fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error() } } impl ops::Rem for NumOpResult { type Output = NumOpResult; fn rem(self, modulus: u64) -> Self::Output { self.and_then(|lhs| lhs % modulus) } } impl ops::Add for SignedAmount { type Output = NumOpResult; fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error() } } impl ops::Add> for SignedAmount { type Output = NumOpResult; fn add(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|a| a + self) } } impl ops::Sub for SignedAmount { type Output = NumOpResult; fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error() } } impl ops::Sub> for SignedAmount { type Output = NumOpResult; fn sub(self, rhs: NumOpResult) -> Self::Output { match rhs { R::Valid(amount) => self - amount, R::Error(_) => rhs, } } } impl ops::Mul for SignedAmount { type Output = NumOpResult; fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error() } } impl ops::Mul for NumOpResult { type Output = NumOpResult; fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) } } impl ops::Mul for i64 { type Output = NumOpResult; fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error() } } impl ops::Mul> for i64 { type Output = NumOpResult; fn mul(self, rhs: NumOpResult) -> Self::Output { rhs.and_then(|rhs| self * rhs) } } impl ops::Div for SignedAmount { type Output = NumOpResult; fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error() } } impl ops::Div for NumOpResult { type Output = NumOpResult; fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) } } impl ops::Div for SignedAmount { type Output = i64; fn div(self, rhs: SignedAmount) -> Self::Output { self.to_sat() / rhs.to_sat() } } impl ops::Rem for SignedAmount { type Output = NumOpResult; fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error() } } impl ops::Rem for NumOpResult { type Output = NumOpResult; fn rem(self, modulus: i64) -> Self::Output { self.and_then(|lhs| lhs % modulus) } } impl ops::Add> for NumOpResult where (T: Copy + ops::Add>) { type Output = NumOpResult; fn add(self, rhs: Self) -> Self::Output { match (self, rhs) { (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, (_, _) => R::Error(NumOpError {}), } } } impl ops::Add for NumOpResult where (T: Copy + ops::Add, Output = NumOpResult>) { type Output = NumOpResult; fn add(self, rhs: T) -> Self::Output { rhs + self } } impl ops::Sub> for NumOpResult where (T: Copy + ops::Sub>) { type Output = NumOpResult; fn sub(self, rhs: Self) -> Self::Output { match (self, rhs) { (R::Valid(lhs), R::Valid(rhs)) => lhs - rhs, (_, _) => R::Error(NumOpError {}), } } } impl ops::Sub for NumOpResult where (T: Copy + ops::Sub>) { type Output = NumOpResult; fn sub(self, rhs: T) -> Self::Output { match self { R::Valid(amount) => amount - rhs, R::Error(_) => self, } } } } impl ops::Neg for SignedAmount { type Output = Self; fn neg(self) -> Self::Output { Self::from_sat(self.to_sat().neg()).expect("all +ve and -ve values are valid") } } impl core::iter::Sum> for NumOpResult { fn sum(iter: I) -> Self where I: Iterator>, { iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) { (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, (_, _) => R::Error(NumOpError {}), }) } } impl<'a> core::iter::Sum<&'a NumOpResult> for NumOpResult { fn sum(iter: I) -> Self where I: Iterator>, { iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) { (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, (_, _) => R::Error(NumOpError {}), }) } } impl core::iter::Sum> for NumOpResult { fn sum(iter: I) -> Self where I: Iterator>, { iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) { (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, (_, _) => R::Error(NumOpError {}), }) } } impl<'a> core::iter::Sum<&'a NumOpResult> for NumOpResult { fn sum(iter: I) -> Self where I: Iterator>, { iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) { (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs, (_, _) => R::Error(NumOpError {}), }) } }