From e520bdad0d047d9f0fb7aad46f44faae9d45de9d Mon Sep 17 00:00:00 2001 From: Sam Cappleman-Lynes Date: Mon, 24 Oct 2016 07:41:59 +0100 Subject: [PATCH] Add scalar multiplication to BigUint, BigInt BigUint and BigInt can now be multiplied by a BigDigit, re-using the same buffer for the output, thereby reducing allocations and copying. --- bigint/src/algorithms.rs | 17 +++++++++++++++++ bigint/src/bigint.rs | 9 +++++++++ bigint/src/biguint.rs | 15 ++++++++++++++- bigint/src/tests/bigint.rs | 25 +++++++++++++++++++++++++ bigint/src/tests/biguint.rs | 20 ++++++++++++++++++++ 5 files changed, 85 insertions(+), 1 deletion(-) diff --git a/bigint/src/algorithms.rs b/bigint/src/algorithms.rs index 0afd6b1..3c1ac9f 100644 --- a/bigint/src/algorithms.rs +++ b/bigint/src/algorithms.rs @@ -85,6 +85,15 @@ pub fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigi lo } +#[inline] +pub fn mul_with_carry(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit { + let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) * (b as DoubleBigDigit) + + (*carry as DoubleBigDigit)); + + *carry = hi; + lo +} + /// Divide a two digit numerator by a one digit divisor, returns quotient and remainder: /// /// Note: the caller must ensure that both the quotient and remainder will fit into a single digit. @@ -377,6 +386,14 @@ pub fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint { prod.normalize() } +pub fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit { + let mut carry = 0; + for a in a.iter_mut() { + *a = mul_with_carry(*a, b, &mut carry); + } + carry +} + pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) { if d.is_zero() { panic!() diff --git a/bigint/src/bigint.rs b/bigint/src/bigint.rs index b86ca52..93f4ae7 100644 --- a/bigint/src/bigint.rs +++ b/bigint/src/bigint.rs @@ -436,6 +436,15 @@ impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { } } +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: BigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data * other) + } +} + forward_all_binop_to_ref_ref!(impl Div for BigInt, div); impl<'a, 'b> Div<&'b BigInt> for &'a BigInt { diff --git a/bigint/src/biguint.rs b/bigint/src/biguint.rs index 340aaae..a5b74d0 100644 --- a/bigint/src/biguint.rs +++ b/bigint/src/biguint.rs @@ -22,7 +22,7 @@ mod algorithms; pub use self::algorithms::big_digit; pub use self::big_digit::{BigDigit, DoubleBigDigit, ZERO_BIG_DIGIT}; -use self::algorithms::{mac_with_carry, mul3, div_rem, div_rem_digit}; +use self::algorithms::{mac_with_carry, mul3, scalar_mul, div_rem, div_rem_digit}; use self::algorithms::{__add2, add2, sub2, sub2rev}; use self::algorithms::{biguint_shl, biguint_shr}; use self::algorithms::{cmp_slice, fls, ilog2}; @@ -431,6 +431,19 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { } } +impl Mul for BigUint { + type Output = BigUint; + + #[inline] + fn mul(mut self, other: BigDigit) -> BigUint { + let carry = scalar_mul(&mut self.data[..], other); + if carry != 0 { + self.data.push(carry); + } + self + } +} + forward_all_binop_to_ref_ref!(impl Div for BigUint, div); impl<'a, 'b> Div<&'b BigUint> for &'a BigUint { diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index aa4319d..9e83373 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -637,6 +637,31 @@ fn test_mul() { } } +#[test] +fn test_scalar_mul() { + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let c = BigInt::from_slice(Plus, c_vec); + let nc = BigInt::from_slice(Minus, c_vec); + + if a_vec.len() == 1 { + let b = BigInt::from_slice(Plus, b_vec); + let nb = BigInt::from_slice(Minus, b_vec); + let a = a_vec[0]; + assert!(b * a == c); + assert!(nb * a == nc); + } + + if b_vec.len() == 1 { + let a = BigInt::from_slice(Plus, a_vec); + let na = BigInt::from_slice(Minus, a_vec); + let b = b_vec[0]; + assert!(a * b == c); + assert!(na * b == nc); + } + } +} + #[test] fn test_div_mod_floor() { fn check_sub(a: &BigInt, b: &BigInt, ans_d: &BigInt, ans_m: &BigInt) { diff --git a/bigint/src/tests/biguint.rs b/bigint/src/tests/biguint.rs index 7c5a423..ac95235 100644 --- a/bigint/src/tests/biguint.rs +++ b/bigint/src/tests/biguint.rs @@ -770,6 +770,26 @@ fn test_mul() { } } +#[test] +fn test_scalar_mul() { + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let c = BigUint::from_slice(c_vec); + + if a_vec.len() == 1 { + let b = BigUint::from_slice(b_vec); + let a = a_vec[0]; + assert!(b * a == c); + } + + if b_vec.len() == 1 { + let a = BigUint::from_slice(a_vec); + let b = b_vec[0]; + assert!(a * b == c); + } + } +} + #[test] fn test_div_rem() { for elm in MUL_TRIPLES.iter() {