From 8b1288ea01d88dfe778a62ff19d1ec6bdab27872 Mon Sep 17 00:00:00 2001 From: Sam Cappleman-Lynes Date: Thu, 29 Jun 2017 15:46:07 +0100 Subject: [PATCH] Add scalar multiplication to BigInt --- bigint/src/bigint.rs | 2 ++ bigint/src/biguint.rs | 10 +++++++--- bigint/src/tests/bigint.rs | 20 +++++++++++--------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/bigint/src/bigint.rs b/bigint/src/bigint.rs index 32af7ab..69937ad 100644 --- a/bigint/src/bigint.rs +++ b/bigint/src/bigint.rs @@ -485,6 +485,8 @@ impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { } } +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); + impl Mul for BigInt { type Output = BigInt; diff --git a/bigint/src/biguint.rs b/bigint/src/biguint.rs index 631686d..28b6e73 100644 --- a/bigint/src/biguint.rs +++ b/bigint/src/biguint.rs @@ -483,9 +483,13 @@ impl Mul for BigUint { #[inline] fn mul(mut self, other: BigDigit) -> BigUint { - let carry = scalar_mul(&mut self.data[..], other); - if carry != 0 { - self.data.push(carry); + if other == 0 { + self.data.clear(); + } else { + let carry = scalar_mul(&mut self.data[..], other); + if carry != 0 { + self.data.push(carry); + } } self } diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index 610dfa9..a24cb3e 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -703,23 +703,25 @@ fn test_mul() { fn test_scalar_mul() { for elm in MUL_TRIPLES.iter() { let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); let c = BigInt::from_slice(Plus, c_vec); - let nc = BigInt::from_slice(Minus, c_vec); + let (na, nb, nc) = (-&a, -&b, -&c); if a_vec.len() == 1 { - let b = BigInt::from_slice(Plus, b_vec); - let nb = BigInt::from_slice(Minus, b_vec); let a = a_vec[0]; - assert!(b * a == c); - assert!(nb * a == nc); + assert_op!(b * a == c); + assert_op!(a * b == c); + assert_op!(nb * a == nc); + assert_op!(a * nb == nc); } if b_vec.len() == 1 { - let a = BigInt::from_slice(Plus, a_vec); - let na = BigInt::from_slice(Minus, a_vec); let b = b_vec[0]; - assert!(a * b == c); - assert!(na * b == nc); + assert_op!(a * b == c); + assert_op!(b * a == c); + assert_op!(na * b == nc); + assert_op!(b * na == nc); } } }