diff --git a/bigint/src/bigint.rs b/bigint/src/bigint.rs index dba6a7a..32af7ab 100644 --- a/bigint/src/bigint.rs +++ b/bigint/src/bigint.rs @@ -445,6 +445,35 @@ impl Sub for BigInt { } } +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigDigit) -> BigInt { + match self.sign { + NoSign => BigInt::from_biguint(Minus, From::from(other)), + Minus => BigInt::from_biguint(Minus, self.data + other), + Plus => + match self.data.cmp(&From::from(other)) { + Equal => Zero::zero(), + Greater => BigInt::from_biguint(Plus, self.data - other), + Less => BigInt::from_biguint(Minus, other - self.data), + } + } + } +} + +impl Sub for BigDigit { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + -(other - self) + } +} + forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul); impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index 4cdf616..610dfa9 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -556,9 +556,10 @@ fn test_add() { fn test_scalar_add() { for elm in SUM_TRIPLES.iter() { let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); let b = BigInt::from_slice(Plus, b_vec); let c = BigInt::from_slice(Plus, c_vec); - let (nb, nc) = (-&b, -&c); + let (na, nb, nc) = (-&a, -&b, -&c); if a_vec.len() == 1 { let a = a_vec[0]; @@ -567,6 +568,14 @@ fn test_scalar_add() { assert_op!(a + nc == nb); assert_op!(nc + a == nb); } + + if b_vec.len() == 1 { + let b = b_vec[0]; + assert_op!(a + b == c); + assert_op!(b + a == c); + assert_op!(b + nc == na); + assert_op!(nc + b == na); + } } } @@ -590,6 +599,41 @@ fn test_sub() { } } +#[test] +fn test_scalar_sub() { + for elm in SUM_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); + let c = BigInt::from_slice(Plus, c_vec); + let (na, nb, nc) = (-&a, -&b, -&c); + + if a_vec.len() == 1 { + let a = a_vec[0]; + assert_op!(c - a == b); + assert_op!(a - c == nb); + assert_op!(a - nb == c); + assert_op!(nb - a == nc); + } + + if b_vec.len() == 1 { + let b = b_vec[0]; + assert_op!(c - b == a); + assert_op!(b - c == na); + assert_op!(b - na == c); + assert_op!(na - b == nc); + } + + if c_vec.len() == 1 { + let c = c_vec[0]; + assert_op!(c - a == b); + assert_op!(a - c == nb); + assert_op!(c - b == a); + assert_op!(b - c == na); + } + } +} + const M: u32 = ::std::u32::MAX; static MUL_TRIPLES: &'static [(&'static [BigDigit], &'static [BigDigit],