diff --git a/bitcoin/src/pow.rs b/bitcoin/src/pow.rs index 709c8dc6..b9b3a9a3 100644 --- a/bitcoin/src/pow.rs +++ b/bitcoin/src/pow.rs @@ -421,32 +421,26 @@ impl U256 { /// Wrapping multiplication by `u64`. /// - /// Returns a tuple of the addition along with a boolean indicating whether an arithmetic - /// overflow would occur. If an overflow would have occurred then the wrapped value is returned. + /// # Returns + /// + /// The multiplication result along with a boolean indicating whether an arithmetic overflow + /// occurred. If an overflow occurred then the wrapped value is returned. fn mul_u64(self, rhs: u64) -> (U256, bool) { - // Multiply 64 bit parts of `mul` by `rhs`. - fn mul_parts(mul: u128, rhs: u64) -> (u128, u128) { - let upper = (rhs as u128) * (mul >> 64); - let lower = (rhs as u128) * (mul & 0xFFFF_FFFF_FFFF_FFFF); - (upper, lower) + let mut carry: u128 = 0; + let mut split_le = [self.1 as u64, (self.1 >> 64) as u64, self.0 as u64, (self.0 >> 64) as u64]; + + for word in &mut split_le { + // TODO: Use `carrying_mul` when stabilized: https://github.com/rust-lang/rust/issues/85532 + // This will not overflow, for proof see https://github.com/rust-bitcoin/rust-bitcoin/pull/1496#issuecomment-1365938572 + let n = carry + u128::from(rhs) * u128::from(*word); + + *word = n as u64; // Intentional truncation, save the low bits + carry = n >> 64; // and carry the high bits. } - if self.is_zero() || rhs == 0 { - return (U256::ZERO, false); - } - - let mut ret = U256::ZERO; - let mut ret_overflow = false; - - let (upper, lower) = mul_parts(self.0, rhs); - ret.0 = lower + (upper << 64); - ret_overflow |= upper >> 64 > 0; - - let (upper, lower) = mul_parts(self.1, rhs); - ret.1 = lower + (upper << 64); - ret.0 += upper >> 64; - - (ret, ret_overflow) + let low = u128::from(split_le[0]) | u128::from(split_le[1]) << 64; + let high = u128::from(split_le[2]) | u128::from(split_le[3]) << 64; + (Self(high, low), carry != 0) } /// Calculates quotient and remainder. @@ -898,6 +892,15 @@ fn split_in_half(a: [u8; 32]) -> ([u8; 16], [u8; 16]) { (high, low) } +#[cfg(kani)] +impl kani::Arbitrary for U256 { + fn any() -> Self { + let high: u128 = kani::any(); + let low: u128 = kani::any(); + Self(high, low) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1613,3 +1616,18 @@ mod tests { work_multiplication_by_u64, 2_u64; } } + +#[cfg(kani)] +mod verification { + use super::*; + + // TODO: After we verify div_rem assert x * y / y == x + #[kani::unwind(5)] // mul_u64 loops over 4 64 bit ints so use one more than 4 + #[kani::proof] + fn check_mul_u64() { + let x: U256 = kani::any(); + let y: u64 = kani::any(); + + let _ = x.mul_u64(y); + } +}