381 lines
12 KiB
Rust
381 lines
12 KiB
Rust
// SPDX-License-Identifier: CC0-1.0
|
|
|
|
//! Provides a monadic type returned by mathematical operations (`core::ops`).
|
|
|
|
use core::convert::Infallible;
|
|
use core::fmt;
|
|
|
|
#[cfg(feature = "arbitrary")]
|
|
use arbitrary::{Arbitrary, Unstructured};
|
|
use NumOpResult as R;
|
|
|
|
use crate::{Amount, FeeRate, SignedAmount, Weight};
|
|
|
|
/// Result of a mathematical operation on two numeric types.
|
|
///
|
|
/// In order to prevent overflow we provide a custom result type that is similar to the normal
|
|
/// [`core::result::Result`] but implements mathematical operations (e.g. [`core::ops::Add`]) so that
|
|
/// math operations can be chained ergonomically. This is very similar to how `NaN` works.
|
|
///
|
|
/// `NumOpResult` is a monadic type that contains `Valid` and `Error` (similar to `Ok` and `Err`).
|
|
/// It supports a subset of functions similar to `Result` (e.g. `unwrap`).
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// The `NumOpResult` type provides protection against overflow and div-by-zero.
|
|
///
|
|
/// ### Overflow protection
|
|
///
|
|
/// ```
|
|
/// # use bitcoin_units::{amount, Amount};
|
|
/// // Example UTXO value.
|
|
/// let a1 = Amount::from_sat(1_000_000)?;
|
|
/// // And another value from some other UTXO.
|
|
/// let a2 = Amount::from_sat(765_432)?;
|
|
/// // Just an example (typically one would calculate fee using weight and fee rate).
|
|
/// let fee = Amount::from_sat(1_00)?;
|
|
/// // The amount we want to send.
|
|
/// let spend = Amount::from_sat(1_200_000)?;
|
|
///
|
|
/// // We can error if the change calculation overflows.
|
|
/// //
|
|
/// // For example if the `spend` value comes from the user and the `change` value is later
|
|
/// // used then overflow here could be an attack vector.
|
|
/// let _change = (a1 + a2 - spend - fee).into_result().expect("handle this error");
|
|
///
|
|
/// // Or if we control all the values and know they are sane we can just `unwrap`.
|
|
/// let _change = (a1 + a2 - spend - fee).unwrap();
|
|
/// // `NumOpResult` also implements `expect`.
|
|
/// let _change = (a1 + a2 - spend - fee).expect("we know values don't overflow");
|
|
/// # Ok::<_, amount::OutOfRangeError>(())
|
|
/// ```
|
|
///
|
|
/// ### Divide-by-zero (overflow in `Div` or `Rem`)
|
|
///
|
|
/// In some instances one may wish to differentiate div-by-zero from overflow.
|
|
///
|
|
/// ```
|
|
/// # use bitcoin_units::{Amount, FeeRate, NumOpResult, NumOpError};
|
|
/// // Two amounts that will be added to calculate the max fee.
|
|
/// let a = Amount::from_sat(123).expect("valid amount");
|
|
/// let b = Amount::from_sat(467).expect("valid amount");
|
|
/// // Fee rate for transaction.
|
|
/// let fee_rate = FeeRate::from_sat_per_vb(1);
|
|
///
|
|
/// // Somewhat contrived example to show addition operator chained with division.
|
|
/// let max_fee = a + b;
|
|
/// let _fee = match max_fee / fee_rate {
|
|
/// NumOpResult::Valid(fee) => fee,
|
|
/// NumOpResult::Error(e) if e.is_div_by_zero() => {
|
|
/// // Do something when div by zero.
|
|
/// return Err(e);
|
|
/// },
|
|
/// NumOpResult::Error(e) => {
|
|
/// // We separate div-by-zero from overflow in case it needs to be handled separately.
|
|
/// //
|
|
/// // This branch could be hit since `max_fee` came from some previous calculation. And if
|
|
/// // an input to that calculation was from the user then overflow could be an attack vector.
|
|
/// return Err(e);
|
|
/// }
|
|
/// };
|
|
/// # Ok::<_, NumOpError>(())
|
|
/// ```
|
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
|
#[must_use]
|
|
pub enum NumOpResult<T> {
|
|
/// Result of a successful mathematical operation.
|
|
Valid(T),
|
|
/// Result of an unsuccessful mathematical operation.
|
|
Error(NumOpError),
|
|
}
|
|
|
|
impl<T> NumOpResult<T> {
|
|
/// Maps a `NumOpResult<T>` to `NumOpResult<U>` by applying a function to a
|
|
/// contained [`NumOpResult::Valid`] value, leaving a [`NumOpResult::Error`] value untouched.
|
|
#[inline]
|
|
pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> NumOpResult<U> {
|
|
match self {
|
|
NumOpResult::Valid(t) => NumOpResult::Valid(op(t)),
|
|
NumOpResult::Error(e) => NumOpResult::Error(e),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: fmt::Debug> NumOpResult<T> {
|
|
/// Returns the contained valid numeric type, consuming `self`.
|
|
///
|
|
/// # Panics
|
|
///
|
|
/// Panics with `msg` if the numeric result is an `Error`.
|
|
#[inline]
|
|
#[track_caller]
|
|
pub fn expect(self, msg: &str) -> T {
|
|
match self {
|
|
R::Valid(x) => x,
|
|
R::Error(_) => panic!("{}", msg),
|
|
}
|
|
}
|
|
|
|
/// Returns the contained valid numeric type, consuming `self`.
|
|
///
|
|
/// # Panics
|
|
///
|
|
/// Panics if the numeric result is an `Error`.
|
|
#[inline]
|
|
#[track_caller]
|
|
pub fn unwrap(self) -> T {
|
|
match self {
|
|
R::Valid(x) => x,
|
|
R::Error(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e),
|
|
}
|
|
}
|
|
|
|
/// Returns the contained error, consuming `self`.
|
|
///
|
|
/// # Panics
|
|
///
|
|
/// Panics if the numeric result is valid.
|
|
#[inline]
|
|
#[track_caller]
|
|
pub fn unwrap_err(self) -> NumOpError {
|
|
match self {
|
|
R::Error(e) => e,
|
|
R::Valid(a) => panic!("tried to unwrap a valid numeric result: {:?}", a),
|
|
}
|
|
}
|
|
|
|
/// Returns the contained Some value or a provided default.
|
|
///
|
|
/// Arguments passed to `unwrap_or` are eagerly evaluated; if you are passing the result of a
|
|
/// function call, it is recommended to use `unwrap_or_else`, which is lazily evaluated.
|
|
#[inline]
|
|
#[track_caller]
|
|
pub fn unwrap_or(self, default: T) -> T {
|
|
match self {
|
|
R::Valid(x) => x,
|
|
R::Error(_) => default,
|
|
}
|
|
}
|
|
|
|
/// Returns the contained `Some` value or computes it from a closure.
|
|
#[inline]
|
|
#[track_caller]
|
|
pub fn unwrap_or_else<F>(self, f: F) -> T
|
|
where
|
|
F: FnOnce() -> T,
|
|
{
|
|
match self {
|
|
R::Valid(x) => x,
|
|
R::Error(_) => f(),
|
|
}
|
|
}
|
|
|
|
/// Converts this `NumOpResult` to an `Option<T>`.
|
|
#[inline]
|
|
pub fn ok(self) -> Option<T> {
|
|
match self {
|
|
R::Valid(x) => Some(x),
|
|
R::Error(_) => None,
|
|
}
|
|
}
|
|
|
|
/// Converts this `NumOpResult` to a `Result<T, NumOpError>`.
|
|
#[inline]
|
|
#[allow(clippy::missing_errors_doc)]
|
|
pub fn into_result(self) -> Result<T, NumOpError> {
|
|
match self {
|
|
R::Valid(x) => Ok(x),
|
|
R::Error(e) => Err(e),
|
|
}
|
|
}
|
|
|
|
/// Calls `op` if the numeric result is `Valid`, otherwise returns the `Error` value of `self`.
|
|
#[inline]
|
|
pub fn and_then<F>(self, op: F) -> NumOpResult<T>
|
|
where
|
|
F: FnOnce(T) -> NumOpResult<T>,
|
|
{
|
|
match self {
|
|
R::Valid(x) => op(x),
|
|
R::Error(e) => R::Error(e),
|
|
}
|
|
}
|
|
|
|
/// Returns `true` if the numeric result is valid.
|
|
#[inline]
|
|
pub fn is_valid(&self) -> bool {
|
|
match self {
|
|
R::Valid(_) => true,
|
|
R::Error(_) => false,
|
|
}
|
|
}
|
|
|
|
/// Returns `true` if the numeric result is invalid.
|
|
#[inline]
|
|
pub fn is_error(&self) -> bool { !self.is_valid() }
|
|
}
|
|
|
|
pub(crate) trait OptionExt<T> {
|
|
fn valid_or_error(self, op: MathOp) -> NumOpResult<T>;
|
|
}
|
|
|
|
macro_rules! impl_opt_ext {
|
|
($($ty:ident),* $(,)?) => {
|
|
$(
|
|
impl OptionExt<$ty> for Option<$ty> {
|
|
#[inline]
|
|
fn valid_or_error(self, op: MathOp) -> NumOpResult<$ty> {
|
|
match self {
|
|
Some(amount) => R::Valid(amount),
|
|
None => R::Error(NumOpError(op)),
|
|
}
|
|
}
|
|
}
|
|
)*
|
|
}
|
|
}
|
|
impl_opt_ext!(Amount, SignedAmount, u64, i64, FeeRate, Weight);
|
|
|
|
/// Error returned when a mathematical operation fails.
|
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
|
#[non_exhaustive]
|
|
pub struct NumOpError(MathOp);
|
|
|
|
impl NumOpError {
|
|
/// Creates a [`NumOpError`] caused by `op`.
|
|
pub(crate) const fn while_doing(op: MathOp) -> Self { NumOpError(op) }
|
|
|
|
/// Returns `true` if this operation error'ed due to overflow.
|
|
pub fn is_overflow(self) -> bool { self.0.is_overflow() }
|
|
|
|
/// Returns `true` if this operation error'ed due to division by zero.
|
|
pub fn is_div_by_zero(self) -> bool { self.0.is_div_by_zero() }
|
|
|
|
/// Returns the [`MathOp`] that caused this error.
|
|
pub fn operation(self) -> MathOp { self.0 }
|
|
}
|
|
|
|
impl fmt::Display for NumOpError {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(f, "math operation '{}' gave an invalid numeric result", self.operation())
|
|
}
|
|
}
|
|
|
|
// Only available with the `std` feature. The trait's default methods are used as-is; the error
// holds just a `MathOp` and does not wrap another error.
#[cfg(feature = "std")]
impl std::error::Error for NumOpError {}
|
|
|
|
/// Identifies the math operation that caused an error.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MathOp {
    /// Addition failed ([`core::ops::Add`] resulted in an invalid value).
    Add,
    /// Subtraction failed ([`core::ops::Sub`] resulted in an invalid value).
    Sub,
    /// Multiplication failed ([`core::ops::Mul`] resulted in an invalid value).
    Mul,
    /// Division failed ([`core::ops::Div`] attempted div-by-zero).
    Div,
    /// Calculating the remainder failed ([`core::ops::Rem`] attempted div-by-zero).
    Rem,
    /// Negation failed ([`core::ops::Neg`] resulted in an invalid value).
    Neg,
    /// Stops users from casting this enum to an integer.
    // May get removed if one day Rust supports disabling casts natively.
    #[doc(hidden)]
    _DoNotUse(Infallible),
}
|
|
|
|
impl MathOp {
|
|
/// Returns `true` if this operation error'ed due to overflow.
|
|
pub fn is_overflow(self) -> bool {
|
|
matches!(self, MathOp::Add | MathOp::Sub | MathOp::Mul | MathOp::Neg)
|
|
}
|
|
|
|
/// Returns `true` if this operation error'ed due to division by zero.
|
|
pub fn is_div_by_zero(self) -> bool { !self.is_overflow() }
|
|
|
|
/// Returns `true` if this operation error'ed due to addition.
|
|
pub fn is_addition(self) -> bool { self == MathOp::Add }
|
|
|
|
/// Returns `true` if this operation error'ed due to subtraction.
|
|
pub fn is_subtraction(self) -> bool { self == MathOp::Sub }
|
|
|
|
/// Returns `true` if this operation error'ed due to multiplication.
|
|
pub fn is_multiplication(self) -> bool { self == MathOp::Mul }
|
|
|
|
/// Returns `true` if this operation error'ed due to negation.
|
|
pub fn is_negation(self) -> bool { self == MathOp::Neg }
|
|
}
|
|
|
|
impl fmt::Display for MathOp {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
match *self {
|
|
MathOp::Add => write!(f, "add"),
|
|
MathOp::Sub => write!(f, "sub"),
|
|
MathOp::Mul => write!(f, "mul"),
|
|
MathOp::Div => write!(f, "div"),
|
|
MathOp::Rem => write!(f, "rem"),
|
|
MathOp::Neg => write!(f, "neg"),
|
|
MathOp::_DoNotUse(infallible) => match infallible {},
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "arbitrary")]
impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NumOpResult<T> {
    /// Draws one integer from `0..=1` to pick the variant, then generates its payload.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let choice = u.int_in_range(0..=1)?;
        if choice == 0 {
            T::arbitrary(u).map(R::Valid)
        } else {
            MathOp::arbitrary(u).map(|op| R::Error(NumOpError(op)))
        }
    }
}
|
|
|
|
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for MathOp {
    /// Draws one integer from `0..=5` and maps it onto one of the six usable variants.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let op = match u.int_in_range(0..=5)? {
            0 => MathOp::Add,
            1 => MathOp::Sub,
            2 => MathOp::Mul,
            3 => MathOp::Div,
            4 => MathOp::Rem,
            _ => MathOp::Neg,
        };
        Ok(op)
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::MathOp;

    #[test]
    fn mathop_predicates() {
        // Operations expected to be classified as overflow.
        for op in &[MathOp::Add, MathOp::Sub, MathOp::Mul, MathOp::Neg] {
            assert!(op.is_overflow());
        }

        // Operations expected to be classified as div-by-zero, and not as overflow.
        for op in &[MathOp::Div, MathOp::Rem] {
            assert!(!op.is_overflow());
            assert!(op.is_div_by_zero());
        }
        assert!(!MathOp::Add.is_div_by_zero());

        // Each single-operation predicate: one matching and one non-matching variant.
        assert!(MathOp::Add.is_addition());
        assert!(!MathOp::Sub.is_addition());

        assert!(MathOp::Sub.is_subtraction());
        assert!(!MathOp::Add.is_subtraction());

        assert!(MathOp::Mul.is_multiplication());
        assert!(!MathOp::Div.is_multiplication());

        assert!(MathOp::Neg.is_negation());
        assert!(!MathOp::Add.is_negation());
    }
}
|