diff --git a/bigint/src/algorithms.rs b/bigint/src/algorithms.rs index 0afd6b1..604fee2 100644 --- a/bigint/src/algorithms.rs +++ b/bigint/src/algorithms.rs @@ -85,6 +85,15 @@ pub fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigi lo } +#[inline] +pub fn mul_with_carry(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit { + let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) * (b as DoubleBigDigit) + + (*carry as DoubleBigDigit)); + + *carry = hi; + lo +} + /// Divide a two digit numerator by a one digit divisor, returns quotient and remainder: /// /// Note: the caller must ensure that both the quotient and remainder will fit into a single digit. @@ -377,6 +386,14 @@ pub fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint { prod.normalize() } +pub fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit { + let mut carry = 0; + for a in a.iter_mut() { + *a = mul_with_carry(*a, b, &mut carry); + } + carry +} + pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) { if d.is_zero() { panic!() @@ -416,7 +433,7 @@ pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) { // q0, our guess, is calculated by dividing the last few digits of a by the last digit of b // - this should give us a guess that is "close" to the actual quotient, but is possibly // greater than the actual quotient. If q0 * b > a, we simply use iterated subtraction - // until we have a guess such that q0 & b <= a. + // until we have a guess such that q0 * b <= a. // let bn = *b.data.last().unwrap(); diff --git a/bigint/src/bigint.rs b/bigint/src/bigint.rs index b86ca52..259f619 100644 --- a/bigint/src/bigint.rs +++ b/bigint/src/bigint.rs @@ -23,11 +23,14 @@ use self::Sign::{Minus, NoSign, Plus}; use super::ParseBigIntError; use super::big_digit; -use super::big_digit::BigDigit; +use super::big_digit::{BigDigit, DoubleBigDigit}; use biguint; use biguint::to_str_radix_reversed; use biguint::BigUint; +use UsizePromotion; +use IsizePromotion; + #[cfg(test)] #[path = "tests/bigint.rs"] mod bigint_tests; @@ -299,6 +302,26 @@ impl Signed for BigInt { } } +// A convenience method for getting the absolute value of an i32 in a u32. +#[inline] +fn i32_abs_as_u32(a: i32) -> u32 { + if a == i32::min_value() { + a as u32 + } else { + a.abs() as u32 + } +} + +// A convenience method for getting the absolute value of an i64 in a u64. +#[inline] +fn i64_abs_as_u64(a: i64) -> u64 { + if a == i64::min_value() { + a as u64 + } else { + a.abs() as u64 + } +} + // We want to forward to BigUint::add, but it's not clear how that will go until // we compare both sign and magnitude. So we duplicate this body for every // val/ref combination, deferring that decision to BigUint's own forwarding. @@ -362,6 +385,75 @@ impl Add for BigInt { } } +promote_all_scalars!(impl Add for BigInt, add); +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigInt, add); +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigInt, add); + +impl Add for BigInt { + type Output = BigInt; + + #[inline] + fn add(self, other: BigDigit) -> BigInt { + match self.sign { + NoSign => From::from(other), + Plus => BigInt::from_biguint(Plus, self.data + other), + Minus => + match self.data.cmp(&From::from(other)) { + Equal => Zero::zero(), + Less => BigInt::from_biguint(Plus, other - self.data), + Greater => BigInt::from_biguint(Minus, self.data - other), + } + } + } +} + +impl Add for BigInt { + type Output = BigInt; + + #[inline] + fn add(self, other: DoubleBigDigit) -> BigInt { + match self.sign { + NoSign => From::from(other), + Plus => BigInt::from_biguint(Plus, self.data + other), + Minus => + match self.data.cmp(&From::from(other)) { + Equal => Zero::zero(), + Less => BigInt::from_biguint(Plus, other - self.data), + Greater => BigInt::from_biguint(Minus, self.data - other), + } + } + } +} + +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigInt, add); +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigInt, add); + +impl Add for BigInt { + type Output = BigInt; + + #[inline] + fn add(self, other: i32) -> BigInt { + if other >= 0 { + self + other as u32 + } else { + self - i32_abs_as_u32(other) + } + } +} + +impl Add for BigInt { + type Output = BigInt; + + #[inline] + fn add(self, other: i64) -> BigInt { + if other >= 0 { + self + other as u64 + } else { + self - i64_abs_as_u64(other) + } + } +} + // We want to forward to BigUint::sub, but it's not clear how that will go until // we compare both sign and magnitude. So we duplicate this body for every // val/ref combination, deferring that decision to BigUint's own forwarding. @@ -425,6 +517,119 @@ impl Sub for BigInt { } } +promote_all_scalars!(impl Sub for BigInt, sub); +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigDigit) -> BigInt { + match self.sign { + NoSign => BigInt::from_biguint(Minus, From::from(other)), + Minus => BigInt::from_biguint(Minus, self.data + other), + Plus => + match self.data.cmp(&From::from(other)) { + Equal => Zero::zero(), + Greater => BigInt::from_biguint(Plus, self.data - other), + Less => BigInt::from_biguint(Minus, other - self.data), + } + } + } +} + +impl Sub for BigDigit { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + -(other - self) + } +} + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: DoubleBigDigit) -> BigInt { + match self.sign { + NoSign => BigInt::from_biguint(Minus, From::from(other)), + Minus => BigInt::from_biguint(Minus, self.data + other), + Plus => + match self.data.cmp(&From::from(other)) { + Equal => Zero::zero(), + Greater => BigInt::from_biguint(Plus, self.data - other), + Less => BigInt::from_biguint(Minus, other - self.data), + } + } + } +} + +impl Sub for DoubleBigDigit { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + -(other - self) + } +} + +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); +forward_all_scalar_binop_to_val_val!(impl Sub for BigInt, sub); + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: i32) -> BigInt { + if other >= 0 { + self - other as u32 + } else { + self + i32_abs_as_u32(other) + } + } +} + +impl Sub for i32 { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 - other + } else { + -other - i32_abs_as_u32(self) + } + } +} + +impl Sub for BigInt { + type Output = BigInt; + + #[inline] + fn sub(self, other: i64) -> BigInt { + if other >= 0 { + self - other as u64 + } else { + self + i64_abs_as_u64(other) + } + } +} + +impl Sub for i64 { + type Output = BigInt; + + #[inline] + fn sub(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u64 - other + } else { + -other - i64_abs_as_u64(self) + } + } +} + forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul); impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { @@ -436,6 +641,57 @@ impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt { } } +promote_all_scalars!(impl Mul for BigInt, mul); +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); + +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: BigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data * other) + } +} + +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: DoubleBigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data * other) + } +} + +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigInt, mul); + +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: i32) -> BigInt { + if other >= 0 { + self * other as u32 + } else { + -(self * i32_abs_as_u32(other)) + } + } +} + +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(self, other: i64) -> BigInt { + if other >= 0 { + self * other as u64 + } else { + -(self * i64_abs_as_u64(other)) + } + } +} + forward_all_binop_to_ref_ref!(impl Div for BigInt, div); impl<'a, 'b> Div<&'b BigInt> for &'a BigInt { @@ -448,6 +704,101 @@ impl<'a, 'b> Div<&'b BigInt> for &'a BigInt { } } +promote_all_scalars!(impl Div for BigInt, div); +forward_all_scalar_binop_to_val_val!(impl Div for BigInt, div); +forward_all_scalar_binop_to_val_val!(impl Div for BigInt, div); + +impl Div for BigInt { + type Output = BigInt; + + #[inline] + fn div(self, other: BigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data / other) + } +} + +impl Div for BigDigit { + type Output = BigInt; + + #[inline] + fn div(self, other: BigInt) -> BigInt { + BigInt::from_biguint(other.sign, self / other.data) + } +} + +impl Div for BigInt { + type Output = BigInt; + + #[inline] + fn div(self, other: DoubleBigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data / other) + } +} + +impl Div for DoubleBigDigit { + type Output = BigInt; + + #[inline] + fn div(self, other: BigInt) -> BigInt { + BigInt::from_biguint(other.sign, self / other.data) + } +} + +forward_all_scalar_binop_to_val_val!(impl Div for BigInt, div); +forward_all_scalar_binop_to_val_val!(impl Div for BigInt, div); + +impl Div for BigInt { + type Output = BigInt; + + #[inline] + fn div(self, other: i32) -> BigInt { + if other >= 0 { + self / other as u32 + } else { + -(self / i32_abs_as_u32(other)) + } + } +} + +impl Div for i32 { + type Output = BigInt; + + #[inline] + fn div(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 / other + } else { + -(i32_abs_as_u32(self) / other) + } + } +} + +impl Div for BigInt { + type Output = BigInt; + + #[inline] + fn div(self, other: i64) -> BigInt { + if other >= 0 { + self / other as u64 + } else { + -(self / i64_abs_as_u64(other)) + } + } +} + +impl Div for i64 { + type Output = BigInt; + + #[inline] + fn div(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u64 / other + } else { + -(i64_abs_as_u64(self) / other) + } + } +} + forward_all_binop_to_ref_ref!(impl Rem for BigInt, rem); impl<'a, 'b> Rem<&'b BigInt> for &'a BigInt { @@ -460,6 +811,101 @@ impl<'a, 'b> Rem<&'b BigInt> for &'a BigInt { } } +promote_all_scalars!(impl Rem for BigInt, rem); +forward_all_scalar_binop_to_val_val!(impl Rem for BigInt, rem); +forward_all_scalar_binop_to_val_val!(impl Rem for BigInt, rem); + +impl Rem for BigInt { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data % other) + } +} + +impl Rem for BigDigit { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigInt) -> BigInt { + BigInt::from_biguint(Plus, self % other.data) + } +} + +impl Rem for BigInt { + type Output = BigInt; + + #[inline] + fn rem(self, other: DoubleBigDigit) -> BigInt { + BigInt::from_biguint(self.sign, self.data % other) + } +} + +impl Rem for DoubleBigDigit { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigInt) -> BigInt { + BigInt::from_biguint(Plus, self % other.data) + } +} + +forward_all_scalar_binop_to_val_val!(impl Rem for BigInt, rem); +forward_all_scalar_binop_to_val_val!(impl Rem for BigInt, rem); + +impl Rem for BigInt { + type Output = BigInt; + + #[inline] + fn rem(self, other: i32) -> BigInt { + if other >= 0 { + self % other as u32 + } else { + self % i32_abs_as_u32(other) + } + } +} + +impl Rem for i32 { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u32 % other + } else { + -(i32_abs_as_u32(self) % other) + } + } +} + +impl Rem for BigInt { + type Output = BigInt; + + #[inline] + fn rem(self, other: i64) -> BigInt { + if other >= 0 { + self % other as u64 + } else { + self % i64_abs_as_u64(other) + } + } +} + +impl Rem for i64 { + type Output = BigInt; + + #[inline] + fn rem(self, other: BigInt) -> BigInt { + if self >= 0 { + self as u64 % other + } else { + -(i64_abs_as_u64(self) % other) + } + } +} + impl Neg for BigInt { type Output = BigInt; diff --git a/bigint/src/biguint.rs b/bigint/src/biguint.rs index 340aaae..2c2d046 100644 --- a/bigint/src/biguint.rs +++ b/bigint/src/biguint.rs @@ -22,11 +22,13 @@ mod algorithms; pub use self::algorithms::big_digit; pub use self::big_digit::{BigDigit, DoubleBigDigit, ZERO_BIG_DIGIT}; -use self::algorithms::{mac_with_carry, mul3, div_rem, div_rem_digit}; +use self::algorithms::{mac_with_carry, mul3, scalar_mul, div_rem, div_rem_digit}; use self::algorithms::{__add2, add2, sub2, sub2rev}; use self::algorithms::{biguint_shl, biguint_shr}; use self::algorithms::{cmp_slice, fls, ilog2}; +use UsizePromotion; + use ParseBigIntError; #[cfg(test)] @@ -394,6 +396,51 @@ impl<'a> Add<&'a BigUint> for BigUint { } } +promote_unsigned_scalars!(impl Add for BigUint, add); +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigUint, add); +forward_all_scalar_binop_to_val_val_commutative!(impl Add for BigUint, add); + +impl Add for BigUint { + type Output = BigUint; + + #[inline] + fn add(mut self, other: BigDigit) -> BigUint { + if other != 0 { + if self.data.len() == 0 { + self.data.push(0); + } + + let carry = __add2(&mut self.data, &[other]); + if carry != 0 { + self.data.push(carry); + } + } + self + } +} + +impl Add for BigUint { + type Output = BigUint; + + #[inline] + fn add(mut self, other: DoubleBigDigit) -> BigUint { + let (hi, lo) = big_digit::from_doublebigdigit(other); + if hi == 0 { + self + lo + } else { + while self.data.len() < 2 { + self.data.push(0); + } + + let carry = __add2(&mut self.data, &[lo, hi]); + if carry != 0 { + self.data.push(carry); + } + self + } + } +} + forward_val_val_binop!(impl Sub for BigUint, sub); forward_ref_ref_binop!(impl Sub for BigUint, sub); @@ -420,6 +467,60 @@ impl<'a> Sub for &'a BigUint { } } +promote_unsigned_scalars!(impl Sub for BigUint, sub); +forward_all_scalar_binop_to_val_val!(impl Sub for BigUint, sub); +forward_all_scalar_binop_to_val_val!(impl Sub for BigUint, sub); + +impl Sub for BigUint { + type Output = BigUint; + + #[inline] + fn sub(mut self, other: BigDigit) -> BigUint { + sub2(&mut self.data[..], &[other]); + self.normalize() + } +} + +impl Sub for BigDigit { + type Output = BigUint; + + #[inline] + fn sub(self, mut other: BigUint) -> BigUint { + if other.data.len() == 0 { + other.data.push(0); + } + + sub2rev(&[self], &mut other.data[..]); + other.normalize() + } +} + +impl Sub for BigUint { + type Output = BigUint; + + #[inline] + fn sub(mut self, other: DoubleBigDigit) -> BigUint { + let (hi, lo) = big_digit::from_doublebigdigit(other); + sub2(&mut self.data[..], &[lo, hi]); + self.normalize() + } +} + +impl Sub for DoubleBigDigit { + type Output = BigUint; + + #[inline] + fn sub(self, mut other: BigUint) -> BigUint { + while other.data.len() < 2 { + other.data.push(0); + } + + let (hi, lo) = big_digit::from_doublebigdigit(self); + sub2rev(&[lo, hi], &mut other.data[..]); + other.normalize() + } +} + forward_all_binop_to_ref_ref!(impl Mul for BigUint, mul); impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { @@ -431,6 +532,44 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { } } +promote_unsigned_scalars!(impl Mul for BigUint, mul); +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigUint, mul); +forward_all_scalar_binop_to_val_val_commutative!(impl Mul for BigUint, mul); + +impl Mul for BigUint { + type Output = BigUint; + + #[inline] + fn mul(mut self, other: BigDigit) -> BigUint { + if other == 0 { + self.data.clear(); + } else { + let carry = scalar_mul(&mut self.data[..], other); + if carry != 0 { + self.data.push(carry); + } + } + self + } +} + +impl Mul for BigUint { + type Output = BigUint; + + #[inline] + fn mul(mut self, other: DoubleBigDigit) -> BigUint { + if other == 0 { + self.data.clear(); + self + } else if other <= BigDigit::max_value() as DoubleBigDigit { + self * other as BigDigit + } else { + let (hi, lo) = big_digit::from_doublebigdigit(other); + mul3(&self.data[..], &[lo, hi]) + } + } +} + forward_all_binop_to_ref_ref!(impl Div for BigUint, div); impl<'a, 'b> Div<&'b BigUint> for &'a BigUint { @@ -439,7 +578,58 @@ impl<'a, 'b> Div<&'b BigUint> for &'a BigUint { #[inline] fn div(self, other: &BigUint) -> BigUint { let (q, _) = self.div_rem(other); - return q; + q + } +} + +promote_unsigned_scalars!(impl Div for BigUint, div); +forward_all_scalar_binop_to_val_val!(impl Div for BigUint, div); +forward_all_scalar_binop_to_val_val!(impl Div for BigUint, div); + +impl Div for BigUint { + type Output = BigUint; + + #[inline] + fn div(self, other: BigDigit) -> BigUint { + let (q, _) = div_rem_digit(self, other); + q + } +} + +impl Div for BigDigit { + type Output = BigUint; + + #[inline] + fn div(self, other: BigUint) -> BigUint { + match other.data.len() { + 0 => panic!(), + 1 => From::from(self / other.data[0]), + _ => Zero::zero(), + } + } +} + +impl Div for BigUint { + type Output = BigUint; + + #[inline] + fn div(self, other: DoubleBigDigit) -> BigUint { + let (q, _) = self.div_rem(&From::from(other)); + q + } +} + +impl Div for DoubleBigDigit { + type Output = BigUint; + + #[inline] + fn div(self, other: BigUint) -> BigUint { + match other.data.len() { + 0 => panic!(), + 1 => From::from(self / other.data[0] as u64), + 2 => From::from(self / big_digit::to_doublebigdigit(other.data[1], other.data[0])), + _ => Zero::zero(), + } } } @@ -451,7 +641,58 @@ impl<'a, 'b> Rem<&'b BigUint> for &'a BigUint { #[inline] fn rem(self, other: &BigUint) -> BigUint { let (_, r) = self.div_rem(other); - return r; + r + } +} + +promote_unsigned_scalars!(impl Rem for BigUint, rem); +forward_all_scalar_binop_to_val_val!(impl Rem for BigUint, rem); +forward_all_scalar_binop_to_val_val!(impl Rem for BigUint, rem); + +impl Rem for BigUint { + type Output = BigUint; + + #[inline] + fn rem(self, other: BigDigit) -> BigUint { + let (_, r) = div_rem_digit(self, other); + From::from(r) + } +} + +impl Rem for BigDigit { + type Output = BigUint; + + #[inline] + fn rem(self, other: BigUint) -> BigUint { + match other.data.len() { + 0 => panic!(), + 1 => From::from(self % other.data[0]), + _ => From::from(self) + } + } +} + +impl Rem for BigUint { + type Output = BigUint; + + #[inline] + fn rem(self, other: DoubleBigDigit) -> BigUint { + let (_, r) = self.div_rem(&From::from(other)); + r + } +} + +impl Rem for DoubleBigDigit { + type Output = BigUint; + + #[inline] + fn rem(self, other: BigUint) -> BigUint { + match other.data.len() { + 0 => panic!(), + 1 => From::from(self % other.data[0] as u64), + 2 => From::from(self % big_digit::to_doublebigdigit(other.data[0], other.data[1])), + _ => From::from(self), + } } } diff --git a/bigint/src/lib.rs b/bigint/src/lib.rs index 586eeec..02b9992 100644 --- a/bigint/src/lib.rs +++ b/bigint/src/lib.rs @@ -88,6 +88,16 @@ use std::error::Error; use std::num::ParseIntError; use std::fmt; +#[cfg(target_pointer_width = "32")] +type UsizePromotion = u32; +#[cfg(target_pointer_width = "64")] +type UsizePromotion = u64; + +#[cfg(target_pointer_width = "32")] +type IsizePromotion = i32; +#[cfg(target_pointer_width = "64")] +type IsizePromotion = i64; + #[derive(Debug, PartialEq)] pub enum ParseBigIntError { ParseInt(ParseIntError), diff --git a/bigint/src/macros.rs b/bigint/src/macros.rs index 39f45a4..7705f55 100644 --- a/bigint/src/macros.rs +++ b/bigint/src/macros.rs @@ -105,6 +105,125 @@ macro_rules! forward_ref_ref_binop_commutative { } } +macro_rules! forward_scalar_val_val_binop_commutative { + (impl $imp:ident<$scalar:ty> for $res:ty, $method: ident) => { + impl $imp<$res> for $scalar { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + $imp::$method(other, self) + } + } + } +} + +macro_rules! forward_scalar_val_ref_binop { + (impl $imp:ident<$scalar:ty> for $res:ty, $method:ident) => { + impl<'a> $imp<&'a $scalar> for $res { + type Output = $res; + + #[inline] + fn $method(self, other: &$scalar) -> $res { + $imp::$method(self, *other) + } + } + + impl<'a> $imp<$res> for &'a $scalar { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + $imp::$method(*self, other) + } + } + } +} + +macro_rules! forward_scalar_ref_val_binop { + (impl $imp:ident<$scalar:ty> for $res:ty, $method:ident) => { + impl<'a> $imp<$scalar> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: $scalar) -> $res { + $imp::$method(self.clone(), other) + } + } + + impl<'a> $imp<&'a $res> for $scalar { + type Output = $res; + + #[inline] + fn $method(self, other: &$res) -> $res { + $imp::$method(self, other.clone()) + } + } + } +} + +macro_rules! forward_scalar_ref_ref_binop { + (impl $imp:ident<$scalar:ty> for $res:ty, $method:ident) => { + impl<'a, 'b> $imp<&'b $scalar> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: &$scalar) -> $res { + $imp::$method(self.clone(), *other) + } + } + + impl<'a, 'b> $imp<&'a $res> for &'b $scalar { + type Output = $res; + + #[inline] + fn $method(self, other: &$res) -> $res { + $imp::$method(*self, other.clone()) + } + } + } +} + +macro_rules! promote_scalars { + (impl $imp:ident<$promo:ty> for $res:ty, $method:ident, $( $scalar:ty ),*) => { + $( + forward_all_scalar_binop_to_val_val!(impl $imp<$scalar> for $res, $method); + + impl $imp<$scalar> for $res { + type Output = $res; + + #[inline] + fn $method(self, other: $scalar) -> $res { + $imp::$method(self, other as $promo) + } + } + + impl $imp<$res> for $scalar { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + $imp::$method(self as $promo, other) + } + } + )* + } +} + +macro_rules! promote_unsigned_scalars { + (impl $imp:ident for $res:ty, $method:ident) => { + promote_scalars!(impl $imp for $res, $method, u8, u16); + promote_scalars!(impl $imp for $res, $method, usize); + } +} + +macro_rules! promote_signed_scalars { + (impl $imp:ident for $res:ty, $method:ident) => { + promote_scalars!(impl $imp for $res, $method, i8, i16); + promote_scalars!(impl $imp for $res, $method, isize); + } +} + // Forward everything to ref-ref, when reusing storage is not helpful macro_rules! forward_all_binop_to_ref_ref { (impl $imp:ident for $res:ty, $method:ident) => { @@ -131,3 +250,25 @@ macro_rules! forward_all_binop_to_val_ref_commutative { forward_ref_ref_binop_commutative!(impl $imp for $res, $method); }; } + +macro_rules! forward_all_scalar_binop_to_val_val { + (impl $imp:ident<$scalar:ty> for $res:ty, $method:ident) => { + forward_scalar_val_ref_binop!(impl $imp<$scalar> for $res, $method); + forward_scalar_ref_val_binop!(impl $imp<$scalar> for $res, $method); + forward_scalar_ref_ref_binop!(impl $imp<$scalar> for $res, $method); + } +} + +macro_rules! forward_all_scalar_binop_to_val_val_commutative { + (impl $imp:ident<$scalar:ty> for $res:ty, $method:ident) => { + forward_scalar_val_val_binop_commutative!(impl $imp<$scalar> for $res, $method); + forward_all_scalar_binop_to_val_val!(impl $imp<$scalar> for $res, $method); + } +} + +macro_rules! promote_all_scalars { + (impl $imp:ident for $res:ty, $method:ident) => { + promote_unsigned_scalars!(impl $imp for $res, $method); + promote_signed_scalars!(impl $imp for $res, $method); + } +} \ No newline at end of file diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index aa4319d..9fbebb0 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -24,6 +24,25 @@ macro_rules! assert_op { }; } +/// Assert that an op works for scalar left or right +macro_rules! assert_scalar_op { + (($($to:ident),*) $left:ident $op:tt $right:ident == $expected:expr) => { + $( + if let Some(left) = $left.$to() { + assert_op!(left $op $right == $expected); + } + if let Some(right) = $right.$to() { + assert_op!($left $op right == $expected); + } + )* + }; + ($left:ident $op:tt $right:ident == $expected:expr) => { + assert_scalar_op!((to_u8, to_u16, to_u32, to_u64, to_usize, + to_i8, to_i16, to_i32, to_i64, to_isize) + $left $op $right == $expected); + }; +} + #[test] fn test_from_biguint() { fn check(inp_s: Sign, inp_n: usize, ans_s: Sign, ans_n: usize) { @@ -552,6 +571,26 @@ fn test_add() { } } +#[test] +fn test_scalar_add() { + for elm in SUM_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); + let c = BigInt::from_slice(Plus, c_vec); + let (na, nb, nc) = (-&a, -&b, -&c); + + assert_scalar_op!(a + b == c); + assert_scalar_op!(b + a == c); + assert_scalar_op!(c + na == b); + assert_scalar_op!(c + nb == a); + assert_scalar_op!(a + nc == nb); + assert_scalar_op!(b + nc == na); + assert_scalar_op!(na + nb == nc); + assert_scalar_op!(a + na == Zero::zero()); + } +} + #[test] fn test_sub() { for elm in SUM_TRIPLES.iter() { @@ -572,6 +611,26 @@ fn test_sub() { } } +#[test] +fn test_scalar_sub() { + for elm in SUM_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); + let c = BigInt::from_slice(Plus, c_vec); + let (na, nb, nc) = (-&a, -&b, -&c); + + assert_scalar_op!(c - a == b); + assert_scalar_op!(c - b == a); + assert_scalar_op!(nb - a == nc); + assert_scalar_op!(na - b == nc); + assert_scalar_op!(b - na == c); + assert_scalar_op!(a - nb == c); + assert_scalar_op!(nc - na == nb); + assert_scalar_op!(a - a == Zero::zero()); + } +} + const M: u32 = ::std::u32::MAX; static MUL_TRIPLES: &'static [(&'static [BigDigit], &'static [BigDigit], @@ -603,6 +662,7 @@ static DIV_REM_QUADRUPLES: &'static [(&'static [BigDigit], &'static [BigDigit], &'static [BigDigit], &'static [BigDigit])] = &[(&[1], &[2], &[], &[1]), + (&[3], &[2], &[1], &[1]), (&[1, 1], &[2], &[M / 2 + 1], &[1]), (&[1, 1, 1], &[2], &[M / 2 + 1, M / 2 + 1], &[1]), (&[0, 1], &[N1], &[1], &[1]), @@ -637,6 +697,24 @@ fn test_mul() { } } +#[test] +fn test_scalar_mul() { + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); + let c = BigInt::from_slice(Plus, c_vec); + let (na, nb, nc) = (-&a, -&b, -&c); + + assert_scalar_op!(a * b == c); + assert_scalar_op!(b * a == c); + assert_scalar_op!(na * nb == c); + + assert_scalar_op!(na * b == nc); + assert_scalar_op!(nb * a == nc); + } +} + #[test] fn test_div_mod_floor() { fn check_sub(a: &BigInt, b: &BigInt, ans_d: &BigInt, ans_m: &BigInt) { @@ -743,6 +821,65 @@ fn test_div_rem() { } } +#[test] +fn test_scalar_div_rem() { + fn check_sub(a: &BigInt, b: BigDigit, ans_q: &BigInt, ans_r: &BigInt) { + let (q, r) = (a / b, a % b); + if !r.is_zero() { + assert_eq!(r.sign, a.sign); + } + assert!(r.abs() <= From::from(b)); + assert!(*a == b * &q + &r); + assert!(q == *ans_q); + assert!(r == *ans_r); + + let (a, b, ans_q, ans_r) = (a.clone(), b.clone(), ans_q.clone(), ans_r.clone()); + assert_op!(a / b == ans_q); + assert_op!(a % b == ans_r); + + if b <= i32::max_value() as u32 { + let nb = -(b as i32); + assert_op!(a / nb == -ans_q.clone()); + assert_op!(a % nb == ans_r); + } + } + + fn check(a: &BigInt, b: BigDigit, q: &BigInt, r: &BigInt) { + check_sub(a, b, q, r); + check_sub(&a.neg(), b, &q.neg(), &r.neg()); + } + + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let b = BigInt::from_slice(Plus, b_vec); + let c = BigInt::from_slice(Plus, c_vec); + + if a_vec.len() == 1 && a_vec[0] != 0 { + let a = a_vec[0]; + check(&c, a, &b, &Zero::zero()); + } + + if b_vec.len() == 1 && b_vec[0] != 0 { + let b = b_vec[0]; + check(&c, b, &a, &Zero::zero()); + } + } + + for elm in DIV_REM_QUADRUPLES.iter() { + let (a_vec, b_vec, c_vec, d_vec) = *elm; + let a = BigInt::from_slice(Plus, a_vec); + let c = BigInt::from_slice(Plus, c_vec); + let d = BigInt::from_slice(Plus, d_vec); + + if b_vec.len() == 1 && b_vec[0] != 0 { + let b = b_vec[0]; + check(&a, b, &c, &d); + } + } + +} + #[test] fn test_checked_add() { for elm in SUM_TRIPLES.iter() { diff --git a/bigint/src/tests/biguint.rs b/bigint/src/tests/biguint.rs index 7c5a423..0f3b743 100644 --- a/bigint/src/tests/biguint.rs +++ b/bigint/src/tests/biguint.rs @@ -25,6 +25,24 @@ macro_rules! assert_op { }; } +/// Assert that an op works for scalar left or right +macro_rules! assert_scalar_op { + (($($to:ident),*) $left:ident $op:tt $right:ident == $expected:expr) => { + $( + if let Some(left) = $left.$to() { + assert_op!(left $op $right == $expected); + } + if let Some(right) = $right.$to() { + assert_op!($left $op right == $expected); + } + )* + }; + ($left:ident $op:tt $right:ident == $expected:expr) => { + assert_scalar_op!((to_u8, to_u16, to_u32, to_u64, to_usize) + $left $op $right == $expected); + }; +} + #[test] fn test_from_slice() { fn check(slice: &[BigDigit], data: &[BigDigit]) { @@ -690,6 +708,19 @@ fn test_add() { } } +#[test] +fn test_scalar_add() { + for elm in SUM_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigUint::from_slice(a_vec); + let b = BigUint::from_slice(b_vec); + let c = BigUint::from_slice(c_vec); + + assert_scalar_op!(a + b == c); + assert_scalar_op!(b + a == c); + } +} + #[test] fn test_sub() { for elm in SUM_TRIPLES.iter() { @@ -703,6 +734,19 @@ fn test_sub() { } } +#[test] +fn test_scalar_sub() { + for elm in SUM_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigUint::from_slice(a_vec); + let b = BigUint::from_slice(b_vec); + let c = BigUint::from_slice(c_vec); + + assert_scalar_op!(c - a == b); + assert_scalar_op!(c - b == a); + } +} + #[test] #[should_panic] fn test_sub_fail_on_underflow() { @@ -741,6 +785,7 @@ const DIV_REM_QUADRUPLES: &'static [(&'static [BigDigit], &'static [BigDigit], &'static [BigDigit], &'static [BigDigit])] = &[(&[1], &[2], &[], &[1]), + (&[3], &[2], &[1], &[1]), (&[1, 1], &[2], &[M / 2 + 1], &[1]), (&[1, 1, 1], &[2], &[M / 2 + 1, M / 2 + 1], &[1]), (&[0, 1], &[N1], &[1], &[1]), @@ -770,6 +815,19 @@ fn test_mul() { } } +#[test] +fn test_scalar_mul() { + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigUint::from_slice(a_vec); + let b = BigUint::from_slice(b_vec); + let c = BigUint::from_slice(c_vec); + + assert_scalar_op!(a * b == c); + assert_scalar_op!(b * a == c); + } +} + #[test] fn test_div_rem() { for elm in MUL_TRIPLES.iter() { @@ -805,6 +863,39 @@ fn test_div_rem() { } } +#[test] +fn test_scalar_div_rem() { + for elm in MUL_TRIPLES.iter() { + let (a_vec, b_vec, c_vec) = *elm; + let a = BigUint::from_slice(a_vec); + let b = BigUint::from_slice(b_vec); + let c = BigUint::from_slice(c_vec); + + if !a.is_zero() { + assert_scalar_op!(c / a == b); + assert_scalar_op!(c % a == Zero::zero()); + } + + if !b.is_zero() { + assert_scalar_op!(c / b == a); + assert_scalar_op!(c % b == Zero::zero()); + } + } + + for elm in DIV_REM_QUADRUPLES.iter() { + let (a_vec, b_vec, c_vec, d_vec) = *elm; + let a = BigUint::from_slice(a_vec); + let b = BigUint::from_slice(b_vec); + let c = BigUint::from_slice(c_vec); + let d = BigUint::from_slice(d_vec); + + if !b.is_zero() { + assert_scalar_op!(a / b == c); + assert_scalar_op!(a % b == d); + } + } +} + #[test] fn test_checked_add() { for elm in SUM_TRIPLES.iter() {