diff --git a/bigint/src/bigint.rs b/bigint/src/bigint.rs index 858b977..848eb09 100644 --- a/bigint/src/bigint.rs +++ b/bigint/src/bigint.rs @@ -299,6 +299,14 @@ impl Signed for BigInt { } } +// A convenience method for getting the absolute value of an i32 in a u32. +fn i32_abs_as_u32(a: i32) -> u32 { + match a.checked_abs() { + Some(x) => x as u32, + None => a as u32 + } +} + // We want to forward to BigUint::add, but it's not clear how that will go until // we compare both sign and magnitude. So we duplicate this body for every // val/ref combination, deferring that decision to BigUint's own forwarding. @@ -382,6 +390,21 @@ impl Add for BigInt { } } +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigInt, add); + +impl Add for BigInt { + type Output = BigInt; + + #[inline] + fn add(self, other: i32) -> BigInt { + if other >= 0 { + self + other as u32 + } else { + self - i32_abs_as_u32(other) + } + } +} + // We want to forward to BigUint::sub, but it's not clear how that will go until // we compare both sign and magnitude. So we duplicate this body for every // val/ref combination, deferring that decision to BigUint's own forwarding. @@ -474,6 +497,34 @@ impl Sub for BigDigit { } } +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: i32) -> BigInt { + if other >= 0 { + self - other as u32 + } else { + self + i32_abs_as_u32(other) + } + } +} + +impl Sub for i32 { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 - other + } else { + -other - i32_abs_as_u32(self) + } + } +} + forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul); impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { @@ -496,6 +547,21 @@ impl Mul for BigInt { } } +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); + +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: i32) -> BigInt { + if other >= 0 { + self * other as u32 + } else { + -(self * i32_abs_as_u32(other)) + } + } +} + forward_all_binop_to_ref_ref!(impl Div for BigInt, div); impl<'a, 'b> Div<&'b BigInt> for &'a BigInt { @@ -528,6 +594,34 @@ impl Div for BigDigit { } } +forward_all_scalar_binop_to_val_val!(impl Div for BigInt, div); + +impl Div for BigInt { + type Output = BigInt; + + #[inline] + fn div(self, other: i32) -> BigInt { + if other >= 0 { + self / other as u32 + } else { + -(self / i32_abs_as_u32(other)) + } + } +} + +impl Div for i32 { + type Output = BigInt; + + #[inline] + fn div(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 / other + } else { + -(i32_abs_as_u32(self) / other) + } + } +} + forward_all_binop_to_ref_ref!(impl Rem for BigInt, rem); impl<'a, 'b> Rem<&'b BigInt> for &'a BigInt { @@ -560,6 +654,34 @@ impl Rem for BigDigit { } } +forward_all_scalar_binop_to_val_val!(impl Rem for BigInt, rem); + +impl Rem for BigInt { + type Output = BigInt; + + #[inline] + fn rem(self, other: i32) -> BigInt { + if other >= 0 { + self % other as u32 + } else { + self % i32_abs_as_u32(other) + } + } +} + +impl Rem for i32 { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 % other + } else { + -(i32_abs_as_u32(self) % other) + } + } +} + impl Neg for BigInt { type Output = BigInt; diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index b8a48dc..ac3d543 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -567,6 +567,14 @@ fn test_scalar_add() { assert_op!(b + a == c); assert_op!(a + nc == nb); assert_op!(nc + a == nb); + + if a <= i32::max_value() as u32 { + let na = -(a as i32); + assert_op!(na + nb == nc); + assert_op!(nb + na == nc); + assert_op!(na + c == b); + assert_op!(c + na == b); + } } if b_vec.len() == 1 { @@ -575,6 +583,14 @@ fn test_scalar_add() { assert_op!(b + a == c); assert_op!(b + nc == na); assert_op!(nc + b == na); + + if b <= i32::max_value() as u32 { + let nb = -(b as i32); + assert_op!(na + nb == nc); + assert_op!(nb + na == nc); + assert_op!(nb + c == a); + assert_op!(c + nb == a); + } } } } @@ -614,6 +630,14 @@ fn test_scalar_sub() { assert_op!(a - c == nb); assert_op!(a - nb == c); assert_op!(nb - a == nc); + + if a <= i32::max_value() as u32 { + let na = -(a as i32); + assert_op!(nc - na == nb); + assert_op!(na - nc == b); + assert_op!(na - b == nc); + assert_op!(b - na == c); + } } if b_vec.len() == 1 { @@ -622,6 +646,14 @@ fn test_scalar_sub() { assert_op!(b - c == na); assert_op!(b - na == c); assert_op!(na - b == nc); + + if b <= i32::max_value() as u32 { + let nb = -(b as i32); + assert_op!(nc - nb == na); + assert_op!(nb - nc == a); + assert_op!(nb - a == nc); + assert_op!(a - nb == c); + } } if c_vec.len() == 1 { @@ -630,6 +662,14 @@ fn test_scalar_sub() { assert_op!(a - c == nb); assert_op!(c - b == a); assert_op!(b - c == na); + + if c <= i32::max_value() as u32 { + let nc = -(c as i32); + assert_op!(nc - na == nb); + assert_op!(na - nc == b); + assert_op!(nc - nb == na); + assert_op!(nb - nc == a); + } } } } @@ -665,6 +705,7 @@ static DIV_REM_QUADRUPLES: &'static [(&'static [BigDigit], &'static [BigDigit], &'static [BigDigit], &'static [BigDigit])] = &[(&[1], &[2], &[], &[1]), + (&[3], &[2], &[1], &[1]), (&[1, 1], &[2], &[M / 2 + 1], &[1]), (&[1, 1, 1], &[2], &[M / 2 + 1, M / 2 + 1], &[1]), (&[0, 1], &[N1], &[1], &[1]), @@ -714,6 +755,14 @@ fn test_scalar_mul() { assert_op!(a * b == c); assert_op!(nb * a == nc); assert_op!(a * nb == nc); + + if a <= i32::max_value() as u32 { + let na = -(a as i32); + assert_op!(nb * na == c); + assert_op!(na * nb == c); + assert_op!(b * na == nc); + assert_op!(na * b == nc); + } } if b_vec.len() == 1 { @@ -722,6 +771,14 @@ fn test_scalar_mul() { assert_op!(b * a == c); assert_op!(na * b == nc); assert_op!(b * na == nc); + + if b <= i32::max_value() as u32 { + let nb = -(b as i32); + assert_op!(na * nb == c); + assert_op!(nb * na == c); + assert_op!(a * nb == nc); + assert_op!(nb * a == nc); + } } } } @@ -847,6 +904,12 @@ fn test_scalar_div_rem() { let (a, b, ans_q, ans_r) = (a.clone(), b.clone(), ans_q.clone(), ans_r.clone()); assert_op!(a / b == ans_q); assert_op!(a % b == ans_r); + + if b <= i32::max_value() as u32 { + let nb = -(b as i32); + assert_op!(a / nb == -ans_q.clone()); + assert_op!(a % nb == ans_r); + } } fn check(a: &BigInt, b: BigDigit, q: &BigInt, r: &BigInt) {