diff --git a/src/util/uint.rs b/src/util/uint.rs index e05e9b99..29f7d1c4 100644 --- a/src/util/uint.rs +++ b/src/util/uint.rs @@ -84,6 +84,42 @@ macro_rules! construct_uint { assert!(init >= 0); $name::from_u64(init as u64) } + + // divmod like operation, returns (quotient, remainder) + #[inline] + fn div_rem(self, other: Self) -> (Self, Self) { + let mut sub_copy = self; + let mut shift_copy = other; + let mut ret = [0u64; $n_words]; + + let my_bits = self.bits(); + let your_bits = other.bits(); + + // Check for division by 0 + assert!(your_bits != 0); + + // Early return in case we are dividing by a larger number than us + if my_bits < your_bits { + return ($name(ret), sub_copy); + } + + // Bitwise long division + let mut shift = my_bits - your_bits; + shift_copy = shift_copy << shift; + loop { + if sub_copy >= shift_copy { + ret[shift / 64] |= 1 << (shift % 64); + sub_copy = sub_copy - shift_copy; + } + shift_copy = shift_copy >> 1; + if shift == 0 { + break; + } + shift -= 1; + } + + ($name(ret), sub_copy) + } } impl ::std::ops::Add<$name> for $name { @@ -134,35 +170,7 @@ macro_rules! construct_uint { type Output = $name; fn div(self, other: $name) -> $name { - let mut sub_copy = self; - let mut shift_copy = other; - let mut ret = [0u64; $n_words]; - - let my_bits = self.bits(); - let your_bits = other.bits(); - - // Check for division by 0 - assert!(your_bits != 0); - - // Early return in case we are dividing by a larger number than us - if my_bits < your_bits { - return $name(ret); - } - - // Bitwise long division - let mut shift = my_bits - your_bits; - shift_copy = shift_copy << shift; - loop { - if sub_copy >= shift_copy { - ret[shift / 64] |= 1 << (shift % 64); - sub_copy = sub_copy - shift_copy; - } - shift_copy = shift_copy >> 1; - if shift == 0 { break; } - shift -= 1; - } - - $name(ret) + self.div_rem(other).0 } } @@ -170,8 +178,7 @@ macro_rules! construct_uint { type Output = $name; fn rem(self, other: $name) -> $name { - let times = self / other; - self - (times * other) + self.div_rem(other).1 } } @@ -568,4 +575,3 @@ mod tests { assert_eq!(end2.ok(), Some(start2)); } } -