Verify and fix mul_u64

Add kani verification for `U256::mul_u64`, doing so uncovered a bug in
the current implementation due to overflow.

Re-write the `mul_u64` method.

Props to Elichai for the algorithm.

Co-authored-by: Elichai Turkel <elichai.turkel@gmail.com>
This commit is contained in:
Tobin C. Harding 2022-12-27 15:01:57 +11:00
parent 615759a8c2
commit cf9733d678
1 changed files with 41 additions and 23 deletions

View File

@ -421,32 +421,26 @@ impl U256 {
/// Wrapping multiplication by `u64`.
///
/// Returns a tuple of the addition along with a boolean indicating whether an arithmetic
/// overflow would occur. If an overflow would have occurred then the wrapped value is returned.
/// # Returns
///
/// The multiplication result along with a boolean indicating whether an arithmetic overflow
/// occurred. If an overflow occurred then the wrapped value is returned.
fn mul_u64(self, rhs: u64) -> (U256, bool) {
// Multiply 64 bit parts of `mul` by `rhs`.
fn mul_parts(mul: u128, rhs: u64) -> (u128, u128) {
let upper = (rhs as u128) * (mul >> 64);
let lower = (rhs as u128) * (mul & 0xFFFF_FFFF_FFFF_FFFF);
(upper, lower)
let mut carry: u128 = 0;
let mut split_le = [self.1 as u64, (self.1 >> 64) as u64, self.0 as u64, (self.0 >> 64) as u64];
for word in &mut split_le {
// TODO: Use `carrying_mul` when stabilized: https://github.com/rust-lang/rust/issues/85532
// This will not overflow, for proof see https://github.com/rust-bitcoin/rust-bitcoin/pull/1496#issuecomment-1365938572
let n = carry + u128::from(rhs) * u128::from(*word);
*word = n as u64; // Intentional truncation, save the low bits
carry = n >> 64; // and carry the high bits.
}
if self.is_zero() || rhs == 0 {
return (U256::ZERO, false);
}
let mut ret = U256::ZERO;
let mut ret_overflow = false;
let (upper, lower) = mul_parts(self.0, rhs);
ret.0 = lower + (upper << 64);
ret_overflow |= upper >> 64 > 0;
let (upper, lower) = mul_parts(self.1, rhs);
ret.1 = lower + (upper << 64);
ret.0 += upper >> 64;
(ret, ret_overflow)
let low = u128::from(split_le[0]) | u128::from(split_le[1]) << 64;
let high = u128::from(split_le[2]) | u128::from(split_le[3]) << 64;
(Self(high, low), carry != 0)
}
/// Calculates quotient and remainder.
@ -898,6 +892,15 @@ fn split_in_half(a: [u8; 32]) -> ([u8; 16], [u8; 16]) {
(high, low)
}
#[cfg(kani)]
impl kani::Arbitrary for U256 {
fn any() -> Self {
let high: u128 = kani::any();
let low: u128 = kani::any();
Self(high, low)
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -1613,3 +1616,18 @@ mod tests {
work_multiplication_by_u64, 2_u64;
}
}
#[cfg(kani)]
mod verification {
use super::*;
// TODO: After we verify div_rem assert x * y / y == x
#[kani::unwind(5)] // mul_u64 loops over 4 64 bit ints so use one more than 4
#[kani::proof]
fn check_mul_u64() {
let x: U256 = kani::any();
let y: u64 = kani::any();
let _ = x.mul_u64(y);
}
}