From 6afac825d9dc0504667d5db2e54b9d377292330e Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Tue, 11 Jul 2017 17:27:19 -0700 Subject: [PATCH] test and fix more scalar add cases --- bigint/src/biguint.rs | 37 ++++++++++++++------------ bigint/src/tests/bigint.rs | 52 +++++++++++++++++++++++++++++++++---- bigint/src/tests/biguint.rs | 32 +++++++++++++++++++---- 3 files changed, 94 insertions(+), 27 deletions(-) diff --git a/bigint/src/biguint.rs b/bigint/src/biguint.rs index db00a72..2c2d046 100644 --- a/bigint/src/biguint.rs +++ b/bigint/src/biguint.rs @@ -405,13 +405,15 @@ impl Add for BigUint { #[inline] fn add(mut self, other: BigDigit) -> BigUint { - if self.data.len() == 0 && other != 0 { - self.data.push(0); - } + if other != 0 { + if self.data.len() == 0 { + self.data.push(0); + } - let carry = __add2(&mut self.data, &[other]); - if carry != 0 { - self.data.push(carry); + let carry = __add2(&mut self.data, &[other]); + if carry != 0 { + self.data.push(carry); + } } self } @@ -422,19 +424,20 @@ impl Add for BigUint { #[inline] fn add(mut self, other: DoubleBigDigit) -> BigUint { - if self.data.len() == 0 && other != 0 { - self.data.push(0); - } - if self.data.len() == 1 && other > BigDigit::max_value() as DoubleBigDigit { - self.data.push(0); - } - let (hi, lo) = big_digit::from_doublebigdigit(other); - let carry = __add2(&mut self.data, &[lo, hi]); - if carry != 0 { - self.data.push(carry); + if hi == 0 { + self + lo + } else { + while self.data.len() < 2 { + self.data.push(0); + } + + let carry = __add2(&mut self.data, &[lo, hi]); + if carry != 0 { + self.data.push(carry); + } + self } - self } } diff --git a/bigint/src/tests/bigint.rs b/bigint/src/tests/bigint.rs index ac3d543..14f86a1 100644 --- a/bigint/src/tests/bigint.rs +++ b/bigint/src/tests/bigint.rs @@ -1,4 +1,4 @@ -use {BigDigit, BigUint, big_digit}; +use {BigDigit, DoubleBigDigit, BigUint, big_digit}; use {Sign, BigInt, RandBigInt, ToBigInt}; use Sign::{Minus, NoSign, Plus}; @@ -532,6 +532,16 @@ const SUM_TRIPLES: &'static [(&'static [BigDigit], (&[1, 1, 1], &[N1, N1], &[0, 1, 2]), (&[2, 2, 1], &[N1, N2], &[1, 1, 2])]; +fn get_scalar(vec: &[BigDigit]) -> BigDigit { + vec.get(0).map_or(0, BigDigit::clone) +} + +fn get_scalar_double(vec: &[BigDigit]) -> DoubleBigDigit { + let lo = vec.get(0).map_or(0, BigDigit::clone); + let hi = vec.get(1).map_or(0, BigDigit::clone); + big_digit::to_doublebigdigit(hi, lo) +} + #[test] fn test_add() { for elm in SUM_TRIPLES.iter() { @@ -561,8 +571,8 @@ fn test_scalar_add() { let c = BigInt::from_slice(Plus, c_vec); let (na, nb, nc) = (-&a, -&b, -&c); - if a_vec.len() == 1 { - let a = a_vec[0]; + if a_vec.len() <= 1 { + let a = get_scalar(a_vec); assert_op!(a + b == c); assert_op!(b + a == c); assert_op!(a + nc == nb); @@ -577,8 +587,24 @@ fn test_scalar_add() { } } - if b_vec.len() == 1 { - let b = b_vec[0]; + if a_vec.len() <= 2 { + let a = get_scalar_double(a_vec); + assert_op!(a + b == c); + assert_op!(b + a == c); + assert_op!(a + nc == nb); + assert_op!(nc + a == nb); + + if a <= i64::max_value() as u64 { + let na = -(a as i64); + assert_op!(na + nb == nc); + assert_op!(nb + na == nc); + assert_op!(na + c == b); + assert_op!(c + na == b); + } + } + + if b_vec.len() <= 1 { + let b = get_scalar(b_vec); assert_op!(a + b == c); assert_op!(b + a == c); assert_op!(b + nc == na); @@ -592,6 +618,22 @@ fn test_scalar_add() { assert_op!(c + nb == a); } } + + if b_vec.len() <= 2 { + let b = get_scalar_double(b_vec); + assert_op!(a + b == c); + assert_op!(b + a == c); + assert_op!(b + nc == na); + assert_op!(nc + b == na); + + if b <= i64::max_value() as u64 { + let nb = -(b as i64); + assert_op!(na + nb == nc); + assert_op!(nb + na == nc); + assert_op!(nb + c == a); + assert_op!(c + nb == a); + } + } } } diff --git a/bigint/src/tests/biguint.rs b/bigint/src/tests/biguint.rs index c39aa40..b8b3492 100644 --- a/bigint/src/tests/biguint.rs +++ b/bigint/src/tests/biguint.rs @@ -1,5 +1,5 @@ use integer::Integer; -use {BigDigit, BigUint, ToBigUint, big_digit}; +use {BigDigit, DoubleBigDigit, BigUint, ToBigUint, big_digit}; use {BigInt, RandBigInt, ToBigInt}; use Sign::Plus; @@ -677,6 +677,16 @@ const SUM_TRIPLES: &'static [(&'static [BigDigit], (&[1, 1, 1], &[N1, N1], &[0, 1, 2]), (&[2, 2, 1], &[N1, N2], &[1, 1, 2])]; +fn get_scalar(vec: &[BigDigit]) -> BigDigit { + vec.get(0).map_or(0, BigDigit::clone) +} + +fn get_scalar_double(vec: &[BigDigit]) -> DoubleBigDigit { + let lo = vec.get(0).map_or(0, BigDigit::clone); + let hi = vec.get(1).map_or(0, BigDigit::clone); + big_digit::to_doublebigdigit(hi, lo) +} + #[test] fn test_add() { for elm in SUM_TRIPLES.iter() { @@ -698,14 +708,26 @@ fn test_scalar_add() { let b = BigUint::from_slice(b_vec); let c = BigUint::from_slice(c_vec); - if a_vec.len() == 1 { - let a = a_vec[0]; + if a_vec.len() <= 1 { + let a = get_scalar(a_vec); assert_op!(a + b == c); assert_op!(b + a == c); } - if b_vec.len() == 1 { - let b = b_vec[0]; + if a_vec.len() <= 2 { + let a = get_scalar_double(a_vec); + assert_op!(a + b == c); + assert_op!(b + a == c); + } + + if b_vec.len() <= 1 { + let b = get_scalar(b_vec); + assert_op!(a + b == c); + assert_op!(b + a == c); + } + + if b_vec.len() <= 2 { + let b = get_scalar_double(b_vec); assert_op!(a + b == c); assert_op!(b + a == c); }