diff --git a/src/bigint.rs b/src/bigint.rs index 2fe98a2..2c7b81b 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -86,7 +86,6 @@ pub type BigDigit = u32; pub type DoubleBigDigit = u64; pub const ZERO_BIG_DIGIT: BigDigit = 0; -static ZERO_VEC: [BigDigit; 1] = [ZERO_BIG_DIGIT]; #[allow(non_snake_case)] pub mod big_digit { @@ -237,7 +236,26 @@ macro_rules! forward_val_val_binop { #[inline] fn $method(self, other: $res) -> $res { - (&self).$method(&other) + // forward to val-ref + $imp::$method(self, &other) + } + } + } +} + +macro_rules! forward_val_val_binop_commutative { + (impl $imp:ident for $res:ty, $method:ident) => { + impl $imp<$res> for $res { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + // forward to val-ref, with the larger capacity as val + if self.data.capacity() >= other.data.capacity() { + $imp::$method(self, &other) + } else { + $imp::$method(other, &self) + } } } } @@ -250,7 +268,22 @@ macro_rules! forward_ref_val_binop { #[inline] fn $method(self, other: $res) -> $res { - self.$method(&other) + // forward to ref-ref + $imp::$method(self, &other) + } + } + } +} + +macro_rules! forward_ref_val_binop_commutative { + (impl $imp:ident for $res:ty, $method:ident) => { + impl<'a> $imp<$res> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + // reverse, forward to val-ref + $imp::$method(other, self) } } } @@ -263,58 +296,121 @@ macro_rules! forward_val_ref_binop { #[inline] fn $method(self, other: &$res) -> $res { - (&self).$method(other) + // forward to ref-ref + $imp::$method(&self, other) } } } } -macro_rules! forward_all_binop { +macro_rules! forward_ref_ref_binop { + (impl $imp:ident for $res:ty, $method:ident) => { + impl<'a, 'b> $imp<&'b $res> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: &$res) -> $res { + // forward to val-ref + $imp::$method(self.clone(), other) + } + } + } +} + +macro_rules! forward_ref_ref_binop_commutative { + (impl $imp:ident for $res:ty, $method:ident) => { + impl<'a, 'b> $imp<&'b $res> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: &$res) -> $res { + // forward to val-ref, choosing the larger to clone + if self.data.len() >= other.data.len() { + $imp::$method(self.clone(), other) + } else { + $imp::$method(other.clone(), self) + } + } + } + } +} + +// Forward everything to ref-ref, when reusing storage is not helpful +macro_rules! forward_all_binop_to_ref_ref { (impl $imp:ident for $res:ty, $method:ident) => { forward_val_val_binop!(impl $imp for $res, $method); - forward_ref_val_binop!(impl $imp for $res, $method); forward_val_ref_binop!(impl $imp for $res, $method); + forward_ref_val_binop!(impl $imp for $res, $method); }; } -forward_all_binop!(impl BitAnd for BigUint, bitand); +// Forward everything to val-ref, so LHS storage can be reused +macro_rules! forward_all_binop_to_val_ref { + (impl $imp:ident for $res:ty, $method:ident) => { + forward_val_val_binop!(impl $imp for $res, $method); + forward_ref_val_binop!(impl $imp for $res, $method); + forward_ref_ref_binop!(impl $imp for $res, $method); + }; +} -impl<'a, 'b> BitAnd<&'b BigUint> for &'a BigUint { +// Forward everything to val-ref, commutatively, so either LHS or RHS storage can be reused +macro_rules! forward_all_binop_to_val_ref_commutative { + (impl $imp:ident for $res:ty, $method:ident) => { + forward_val_val_binop_commutative!(impl $imp for $res, $method); + forward_ref_val_binop_commutative!(impl $imp for $res, $method); + forward_ref_ref_binop_commutative!(impl $imp for $res, $method); + }; +} + +forward_all_binop_to_val_ref_commutative!(impl BitAnd for BigUint, bitand); + +impl<'a> BitAnd<&'a BigUint> for BigUint { type Output = BigUint; #[inline] fn bitand(self, other: &BigUint) -> BigUint { - BigUint::new(self.data.iter().zip(other.data.iter()).map(|(ai, bi)| *ai & *bi).collect()) + let mut data = self.data; + for (ai, &bi) in data.iter_mut().zip(other.data.iter()) { + *ai &= bi; + } + data.truncate(other.data.len()); + BigUint::new(data) } } -forward_all_binop!(impl BitOr for BigUint, bitor); +forward_all_binop_to_val_ref_commutative!(impl BitOr for BigUint, bitor); -impl<'a, 'b> BitOr<&'b BigUint> for &'a BigUint { +impl<'a> BitOr<&'a BigUint> for BigUint { type Output = BigUint; fn bitor(self, other: &BigUint) -> BigUint { - let zeros = ZERO_VEC.iter().cycle(); - let (a, b) = if self.data.len() > other.data.len() { (self, other) } else { (other, self) }; - let ored = a.data.iter().zip(b.data.iter().chain(zeros)).map( - |(ai, bi)| *ai | *bi - ).collect(); - return BigUint::new(ored); + let mut data = self.data; + for (ai, &bi) in data.iter_mut().zip(other.data.iter()) { + *ai |= bi; + } + if other.data.len() > data.len() { + let extra = &other.data[data.len()..]; + data.extend(extra.iter().cloned()); + } + BigUint::new(data) } } -forward_all_binop!(impl BitXor for BigUint, bitxor); +forward_all_binop_to_val_ref_commutative!(impl BitXor for BigUint, bitxor); -impl<'a, 'b> BitXor<&'b BigUint> for &'a BigUint { +impl<'a> BitXor<&'a BigUint> for BigUint { type Output = BigUint; fn bitxor(self, other: &BigUint) -> BigUint { - let zeros = ZERO_VEC.iter().cycle(); - let (a, b) = if self.data.len() > other.data.len() { (self, other) } else { (other, self) }; - let xored = a.data.iter().zip(b.data.iter().chain(zeros)).map( - |(ai, bi)| *ai ^ *bi - ).collect(); - return BigUint::new(xored); + let mut data = self.data; + for (ai, &bi) in data.iter_mut().zip(other.data.iter()) { + *ai ^= bi; + } + if other.data.len() > data.len() { + let extra = &other.data[data.len()..]; + data.extend(extra.iter().cloned()); + } + BigUint::new(data) } } @@ -332,7 +428,7 @@ impl<'a> Shl for &'a BigUint { fn shl(self, rhs: usize) -> BigUint { let n_unit = rhs / big_digit::BITS; let n_bits = rhs % big_digit::BITS; - return self.shl_unit(n_unit).shl_bits(n_bits); + self.shl_unit(n_unit).shl_bits(n_bits) } } @@ -350,7 +446,7 @@ impl<'a> Shr for &'a BigUint { fn shr(self, rhs: usize) -> BigUint { let n_unit = rhs / big_digit::BITS; let n_bits = rhs % big_digit::BITS; - return self.shr_unit(n_unit).shr_bits(n_bits); + self.shr_unit(n_unit).shr_bits(n_bits) } } @@ -369,70 +465,85 @@ impl One for BigUint { impl Unsigned for BigUint {} -forward_all_binop!(impl Add for BigUint, add); +forward_all_binop_to_val_ref_commutative!(impl Add for BigUint, add); -impl<'a, 'b> Add<&'b BigUint> for &'a BigUint { +impl<'a> Add<&'a BigUint> for BigUint { type Output = BigUint; fn add(self, other: &BigUint) -> BigUint { - let zeros = ZERO_VEC.iter().cycle(); - let (a, b) = if self.data.len() > other.data.len() { (self, other) } else { (other, self) }; + let mut sum = self.data; + + if other.data.len() > sum.len() { + let additional = other.data.len() - sum.len(); + sum.reserve(additional); + sum.extend(repeat(ZERO_BIG_DIGIT).take(additional)); + } + let other_iter = other.data.iter().cloned().chain(repeat(ZERO_BIG_DIGIT)); let mut carry = 0; - let mut sum: Vec = a.data.iter().zip(b.data.iter().chain(zeros)).map(|(ai, bi)| { - let (hi, lo) = big_digit::from_doublebigdigit( - (*ai as DoubleBigDigit) + (*bi as DoubleBigDigit) + (carry as DoubleBigDigit)); + for (a, b) in sum.iter_mut().zip(other_iter) { + let d = (*a as DoubleBigDigit) + + (b as DoubleBigDigit) + + (carry as DoubleBigDigit); + let (hi, lo) = big_digit::from_doublebigdigit(d); carry = hi; - lo - }).collect(); + *a = lo; + } + if carry != 0 { sum.push(carry); } - return BigUint::new(sum); + BigUint::new(sum) } } -forward_all_binop!(impl Sub for BigUint, sub); +forward_all_binop_to_val_ref!(impl Sub for BigUint, sub); -impl<'a, 'b> Sub<&'b BigUint> for &'a BigUint { +impl<'a> Sub<&'a BigUint> for BigUint { type Output = BigUint; fn sub(self, other: &BigUint) -> BigUint { - let new_len = cmp::max(self.data.len(), other.data.len()); - let zeros = ZERO_VEC.iter().cycle(); - let (a, b) = (self.data.iter().chain(zeros.clone()), other.data.iter().chain(zeros)); + let mut diff = self.data; + let other = &other.data; + assert!(diff.len() >= other.len(), "arithmetic operation overflowed"); - let mut borrow = 0isize; - let diff: Vec = a.take(new_len).zip(b).map(|(ai, bi)| { - let (hi, lo) = big_digit::from_doublebigdigit( - big_digit::BASE - + (*ai as DoubleBigDigit) - - (*bi as DoubleBigDigit) - - (borrow as DoubleBigDigit) - ); + let mut borrow: DoubleBigDigit = 0; + for (a, &b) in diff.iter_mut().zip(other.iter()) { + let d = big_digit::BASE - borrow + + (*a as DoubleBigDigit) + - (b as DoubleBigDigit); + let (hi, lo) = big_digit::from_doublebigdigit(d); /* hi * (base) + lo == 1*(base) + ai - bi - borrow => ai - bi - borrow < 0 <=> hi == 0 */ borrow = if hi == 0 { 1 } else { 0 }; - lo - }).collect(); + *a = lo; + } - assert!(borrow == 0, - "Cannot subtract other from self because other is larger than self."); - return BigUint::new(diff); + for a in &mut diff[other.len()..] { + if borrow == 0 { break } + let d = big_digit::BASE - borrow + + (*a as DoubleBigDigit); + let (hi, lo) = big_digit::from_doublebigdigit(d); + borrow = if hi == 0 { 1 } else { 0 }; + *a = lo; + } + + assert!(borrow == 0, "arithmetic operation overflowed"); + BigUint::new(diff) } } -forward_all_binop!(impl Mul for BigUint, mul); +forward_all_binop_to_val_ref_commutative!(impl Mul for BigUint, mul); -impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { +impl<'a> Mul<&'a BigUint> for BigUint { type Output = BigUint; fn mul(self, other: &BigUint) -> BigUint { if self.is_zero() || other.is_zero() { return Zero::zero(); } let (s_len, o_len) = (self.data.len(), other.data.len()); - if s_len == 1 { return mul_digit(other, self.data[0]); } + if s_len == 1 { return mul_digit(other.clone(), self.data[0]); } if o_len == 1 { return mul_digit(self, other.data[0]); } // Using Karatsuba multiplication @@ -442,7 +553,7 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { // a0*b0 let half_len = cmp::max(s_len, o_len) / 2; let (s_hi, s_lo) = cut_at(self, half_len); - let (o_hi, o_lo) = cut_at(other, half_len); + let (o_hi, o_lo) = cut_at(other.clone(), half_len); let ll = &s_lo * &o_lo; let hh = &s_hi * &o_hi; @@ -459,27 +570,30 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { return ll + mm.shl_unit(half_len) + hh.shl_unit(half_len * 2); - fn mul_digit(a: &BigUint, n: BigDigit) -> BigUint { + fn mul_digit(a: BigUint, n: BigDigit) -> BigUint { if n == 0 { return Zero::zero(); } - if n == 1 { return a.clone(); } + if n == 1 { return a; } let mut carry = 0; - let mut prod: Vec = a.data.iter().map(|ai| { - let (hi, lo) = big_digit::from_doublebigdigit( - (*ai as DoubleBigDigit) * (n as DoubleBigDigit) + (carry as DoubleBigDigit) - ); + let mut prod = a.data; + for a in &mut prod { + let d = (*a as DoubleBigDigit) + * (n as DoubleBigDigit) + + (carry as DoubleBigDigit); + let (hi, lo) = big_digit::from_doublebigdigit(d); carry = hi; - lo - }).collect(); + *a = lo; + } if carry != 0 { prod.push(carry); } - return BigUint::new(prod); + BigUint::new(prod) } #[inline] - fn cut_at(a: &BigUint, n: usize) -> (BigUint, BigUint) { + fn cut_at(mut a: BigUint, n: usize) -> (BigUint, BigUint) { let mid = cmp::min(a.data.len(), n); - (BigUint::from_slice(&a.data[mid ..]), - BigUint::from_slice(&a.data[.. mid])) + let hi = BigUint::from_slice(&a.data[mid ..]); + a.data.truncate(mid); + (hi, BigUint::new(a.data)) } #[inline] @@ -494,7 +608,7 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { } -forward_all_binop!(impl Div for BigUint, div); +forward_all_binop_to_ref_ref!(impl Div for BigUint, div); impl<'a, 'b> Div<&'b BigUint> for &'a BigUint { type Output = BigUint; @@ -506,7 +620,7 @@ impl<'a, 'b> Div<&'b BigUint> for &'a BigUint { } } -forward_all_binop!(impl Rem for BigUint, rem); +forward_all_binop_to_ref_ref!(impl Rem for BigUint, rem); impl<'a, 'b> Rem<&'b BigUint> for &'a BigUint { type Output = BigUint; @@ -587,10 +701,10 @@ impl Integer for BigUint { fn div_mod_floor(&self, other: &BigUint) -> (BigUint, BigUint) { if other.is_zero() { panic!() } if self.is_zero() { return (Zero::zero(), Zero::zero()); } - if *other == One::one() { return ((*self).clone(), Zero::zero()); } + if *other == One::one() { return (self.clone(), Zero::zero()); } match self.cmp(other) { - Less => return (Zero::zero(), (*self).clone()), + Less => return (Zero::zero(), self.clone()), Equal => return (One::one(), Zero::zero()), Greater => {} // Do nothing } @@ -1007,48 +1121,53 @@ impl BigUint { #[inline] fn shl_unit(&self, n_unit: usize) -> BigUint { - if n_unit == 0 || self.is_zero() { return (*self).clone(); } + if n_unit == 0 || self.is_zero() { return self.clone(); } - let mut v = repeat(ZERO_BIG_DIGIT).take(n_unit).collect::>(); + let mut v = vec![0; n_unit]; v.extend(self.data.iter().cloned()); BigUint::new(v) } #[inline] - fn shl_bits(&self, n_bits: usize) -> BigUint { - if n_bits == 0 || self.is_zero() { return (*self).clone(); } + fn shl_bits(self, n_bits: usize) -> BigUint { + if n_bits == 0 || self.is_zero() { return self; } + + assert!(n_bits < big_digit::BITS); let mut carry = 0; - let mut shifted: Vec = self.data.iter().map(|elem| { - let (hi, lo) = big_digit::from_doublebigdigit( - (*elem as DoubleBigDigit) << n_bits | (carry as DoubleBigDigit) - ); - carry = hi; - lo - }).collect(); - if carry != 0 { shifted.push(carry); } - return BigUint::new(shifted); + let mut shifted = self.data; + for elem in shifted.iter_mut() { + let new_carry = *elem >> (big_digit::BITS - n_bits); + *elem = (*elem << n_bits) | carry; + carry = new_carry; + } + if carry != 0 { + shifted.push(carry); + } + BigUint::new(shifted) } #[inline] fn shr_unit(&self, n_unit: usize) -> BigUint { - if n_unit == 0 { return (*self).clone(); } + if n_unit == 0 { return self.clone(); } if self.data.len() < n_unit { return Zero::zero(); } BigUint::from_slice(&self.data[n_unit ..]) } #[inline] - fn shr_bits(&self, n_bits: usize) -> BigUint { - if n_bits == 0 || self.data.is_empty() { return (*self).clone(); } + fn shr_bits(self, n_bits: usize) -> BigUint { + if n_bits == 0 || self.data.is_empty() { return self; } + + assert!(n_bits < big_digit::BITS); let mut borrow = 0; - let mut shifted_rev = Vec::with_capacity(self.data.len()); - for elem in self.data.iter().rev() { - shifted_rev.push((*elem >> n_bits) | borrow); - borrow = *elem << (big_digit::BITS - n_bits); + let mut shifted = self.data; + for elem in shifted.iter_mut().rev() { + let new_borrow = *elem << (big_digit::BITS - n_bits); + *elem = (*elem >> n_bits) | borrow; + borrow = new_borrow; } - let shifted = { shifted_rev.reverse(); shifted_rev }; - return BigUint::new(shifted); + BigUint::new(shifted) } /// Determines the fewest bits necessary to express the `BigUint`.