Improve multiply performance
The main idea here is to do as much as possible with slices, instead of allocating new BigUints (i.e. heap allocations).

Current performance:

multiply_0: 10,507 ns/iter (+/- 987)
multiply_1: 2,788,734 ns/iter (+/- 100,079)
multiply_2: 69,923,515 ns/iter (+/- 4,550,902)

After this patch, we get:

multiply_0: 364 ns/iter (+/- 62)
multiply_1: 34,085 ns/iter (+/- 1,179)
multiply_2: 3,753,883 ns/iter (+/- 46,876)
This commit is contained in:
parent
496ae0337c
commit
08b0022aab
285
src/bigint.rs
285
src/bigint.rs
|
@ -148,6 +148,16 @@ fn sbb(a: BigDigit, b: BigDigit, borrow: &mut BigDigit) -> BigDigit {
|
||||||
lo
|
lo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigit) -> BigDigit {
|
||||||
|
let (hi, lo) = big_digit::from_doublebigdigit(
|
||||||
|
(a as DoubleBigDigit) +
|
||||||
|
(b as DoubleBigDigit) * (c as DoubleBigDigit) +
|
||||||
|
(*carry as DoubleBigDigit));
|
||||||
|
*carry = hi;
|
||||||
|
lo
|
||||||
|
}
|
||||||
|
|
||||||
/// A big unsigned integer type.
|
/// A big unsigned integer type.
|
||||||
///
|
///
|
||||||
/// A `BigUint`-typed value `BigUint { data: vec!(a, b, c) }` represents a number
|
/// A `BigUint`-typed value `BigUint { data: vec!(a, b, c) }` represents a number
|
||||||
|
@ -172,18 +182,25 @@ impl PartialOrd for BigUint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
|
||||||
|
debug_assert!(a.last() != Some(&0));
|
||||||
|
debug_assert!(b.last() != Some(&0));
|
||||||
|
|
||||||
|
let (a_len, b_len) = (a.len(), b.len());
|
||||||
|
if a_len < b_len { return Less; }
|
||||||
|
if a_len > b_len { return Greater; }
|
||||||
|
|
||||||
|
for (&ai, &bi) in a.iter().rev().zip(b.iter().rev()) {
|
||||||
|
if ai < bi { return Less; }
|
||||||
|
if ai > bi { return Greater; }
|
||||||
|
}
|
||||||
|
return Equal;
|
||||||
|
}
|
||||||
|
|
||||||
impl Ord for BigUint {
|
impl Ord for BigUint {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn cmp(&self, other: &BigUint) -> Ordering {
|
fn cmp(&self, other: &BigUint) -> Ordering {
|
||||||
let (s_len, o_len) = (self.data.len(), other.data.len());
|
cmp_slice(&self.data[..], &other.data[..])
|
||||||
if s_len < o_len { return Less; }
|
|
||||||
if s_len > o_len { return Greater; }
|
|
||||||
|
|
||||||
for (&self_i, &other_i) in self.data.iter().rev().zip(other.data.iter().rev()) {
|
|
||||||
if self_i < other_i { return Less; }
|
|
||||||
if self_i > other_i { return Greater; }
|
|
||||||
}
|
|
||||||
return Equal;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -608,80 +625,202 @@ impl<'a> Sub<&'a BigUint> for BigUint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
|
||||||
|
// Normalize:
|
||||||
|
let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
|
||||||
|
let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
|
||||||
|
|
||||||
forward_all_binop_to_val_ref_commutative!(impl Mul for BigUint, mul);
|
match cmp_slice(a, b) {
|
||||||
|
Greater => {
|
||||||
impl<'a> Mul<&'a BigUint> for BigUint {
|
let mut ret = BigUint::from_slice(a);
|
||||||
type Output = BigUint;
|
sub2(&mut ret.data[..], b);
|
||||||
|
BigInt::from_biguint(Plus, ret.normalize())
|
||||||
fn mul(self, other: &BigUint) -> BigUint {
|
},
|
||||||
if self.is_zero() || other.is_zero() { return Zero::zero(); }
|
Less => {
|
||||||
|
let mut ret = BigUint::from_slice(b);
|
||||||
let (s_len, o_len) = (self.data.len(), other.data.len());
|
sub2(&mut ret.data[..], a);
|
||||||
if s_len == 1 { return mul_digit(other.clone(), self.data[0]); }
|
BigInt::from_biguint(Minus, ret.normalize())
|
||||||
if o_len == 1 { return mul_digit(self, other.data[0]); }
|
},
|
||||||
|
_ => Zero::zero(),
|
||||||
// Using Karatsuba multiplication
|
|
||||||
// (a1 * base + a0) * (b1 * base + b0)
|
|
||||||
// = a1*b1 * base^2 +
|
|
||||||
// (a1*b1 + a0*b0 - (a1-b0)*(b1-a0)) * base +
|
|
||||||
// a0*b0
|
|
||||||
let half_len = cmp::max(s_len, o_len) / 2;
|
|
||||||
let (s_hi, s_lo) = cut_at(self, half_len);
|
|
||||||
let (o_hi, o_lo) = cut_at(other.clone(), half_len);
|
|
||||||
|
|
||||||
let ll = &s_lo * &o_lo;
|
|
||||||
let hh = &s_hi * &o_hi;
|
|
||||||
let mm = {
|
|
||||||
let (s1, n1) = sub_sign(s_hi, s_lo);
|
|
||||||
let (s2, n2) = sub_sign(o_hi, o_lo);
|
|
||||||
match (s1, s2) {
|
|
||||||
(Equal, _) | (_, Equal) => &hh + &ll,
|
|
||||||
(Less, Greater) | (Greater, Less) => &hh + &ll + (n1 * n2),
|
|
||||||
(Less, Less) | (Greater, Greater) => &hh + &ll - (n1 * n2)
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
return ll + mm.shl_unit(half_len) + hh.shl_unit(half_len * 2);
|
forward_all_binop_to_ref_ref!(impl Mul for BigUint, mul);
|
||||||
|
|
||||||
|
/// Three argument multiply accumulate:
|
||||||
|
/// acc += b * c
|
||||||
|
fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
|
||||||
|
if c == 0 { return; }
|
||||||
|
|
||||||
fn mul_digit(a: BigUint, n: BigDigit) -> BigUint {
|
let mut b_iter = b.iter();
|
||||||
if n == 0 { return Zero::zero(); }
|
|
||||||
if n == 1 { return a; }
|
|
||||||
|
|
||||||
let mut carry = 0;
|
let mut carry = 0;
|
||||||
let mut prod = a.data;
|
|
||||||
for a in &mut prod {
|
for ai in acc.iter_mut() {
|
||||||
let d = (*a as DoubleBigDigit)
|
if let Some(bi) = b_iter.next() {
|
||||||
* (n as DoubleBigDigit)
|
*ai = mac_with_carry(*ai, *bi, c, &mut carry);
|
||||||
+ (carry as DoubleBigDigit);
|
} else if carry != 0 {
|
||||||
let (hi, lo) = big_digit::from_doublebigdigit(d);
|
*ai = mac_with_carry(*ai, 0, c, &mut carry);
|
||||||
carry = hi;
|
} else {
|
||||||
*a = lo;
|
break;
|
||||||
}
|
}
|
||||||
if carry != 0 { prod.push(carry); }
|
|
||||||
BigUint::new(prod)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
assert!(carry == 0);
|
||||||
fn cut_at(mut a: BigUint, n: usize) -> (BigUint, BigUint) {
|
}
|
||||||
let mid = cmp::min(a.data.len(), n);
|
|
||||||
let hi = BigUint::from_slice(&a.data[mid ..]);
|
|
||||||
a.data.truncate(mid);
|
|
||||||
(hi, BigUint::new(a.data))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
/// Three argument multiply accumulate:
|
||||||
fn sub_sign(a: BigUint, b: BigUint) -> (Ordering, BigUint) {
|
/// acc += b * c
|
||||||
match a.cmp(&b) {
|
fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
|
||||||
Less => (Less, b - a),
|
let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
|
||||||
Greater => (Greater, a - b),
|
|
||||||
_ => (Equal, Zero::zero())
|
/*
|
||||||
|
* Karatsuba multiplication is slower than long multiplication for small x and y:
|
||||||
|
*/
|
||||||
|
if x.len() <= 4 {
|
||||||
|
for (i, xi) in x.iter().enumerate() {
|
||||||
|
mac_digit(&mut acc[i..], y, *xi);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
/*
|
||||||
|
* Karatsuba multiplication:
|
||||||
|
*
|
||||||
|
* The idea is that we break x and y up into two smaller numbers that each have about half
|
||||||
|
* as many digits, like so (note that multiplying by b is just a shift):
|
||||||
|
*
|
||||||
|
* x = x0 + x1 * b
|
||||||
|
* y = y0 + y1 * b
|
||||||
|
*
|
||||||
|
* With some algebra, we can compute x * y with three smaller products, where the inputs to
|
||||||
|
* each of the smaller products have only about half as many digits as x and y:
|
||||||
|
*
|
||||||
|
* x * y = (x0 + x1 * b) * (y0 + y1 * b)
|
||||||
|
*
|
||||||
|
* x * y = x0 * y0
|
||||||
|
* + x0 * y1 * b
|
||||||
|
* + x1 * y0 * b
|
||||||
|
* + x1 * y1 * b^2
|
||||||
|
*
|
||||||
|
* Let p0 = x0 * y0 and p2 = x1 * y1:
|
||||||
|
*
|
||||||
|
* x * y = p0
|
||||||
|
* + (x0 * y1 + x1 * y0) * b
|
||||||
|
* + p2 * b^2
|
||||||
|
*
|
||||||
|
* The real trick is that middle term:
|
||||||
|
*
|
||||||
|
* x0 * y1 + x1 * y0
|
||||||
|
*
|
||||||
|
* = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
|
||||||
|
*
|
||||||
|
* = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
|
||||||
|
*
|
||||||
|
* Now we factor:
|
||||||
|
*
|
||||||
|
* = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
|
||||||
|
*
|
||||||
|
* = -((x1 - x0) * (y1 - y0)) + p0 + p2
|
||||||
|
*
|
||||||
|
* Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
|
||||||
|
*
|
||||||
|
* x * y = p0
|
||||||
|
* + (p0 + p2 - p1) * b
|
||||||
|
* + p2 * b^2
|
||||||
|
*
|
||||||
|
* Where the three intermediate products are:
|
||||||
|
*
|
||||||
|
* p0 = x0 * y0
|
||||||
|
* p1 = (x1 - x0) * (y1 - y0)
|
||||||
|
* p2 = x1 * y1
|
||||||
|
*
|
||||||
|
* In doing the computation, we take great care to avoid unnecessary temporary variables
|
||||||
|
* (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
|
||||||
|
* bit so we can use the same temporary variable for all the intermediate products:
|
||||||
|
*
|
||||||
|
* x * y = p2 * b^2 + p2 * b
|
||||||
|
* + p0 * b + p0
|
||||||
|
* - p1 * b
|
||||||
|
*
|
||||||
|
* The other trick we use is instead of doing explicit shifts, we slice acc at the
|
||||||
|
* appropriate offset when doing the add.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* When x is smaller than y, it's significantly faster to pick b such that x is split in
|
||||||
|
* half, not y:
|
||||||
|
*/
|
||||||
|
let b = x.len() / 2;
|
||||||
|
let (x0, x1) = x.split_at(b);
|
||||||
|
let (y0, y1) = y.split_at(b);
|
||||||
|
|
||||||
|
/* We reuse the same BigUint for all the intermediate multiplies: */
|
||||||
|
|
||||||
|
let len = y.len() + 1;
|
||||||
|
let mut p: BigUint = BigUint { data: Vec::with_capacity(len) };
|
||||||
|
p.data.extend(repeat(0).take(len));
|
||||||
|
|
||||||
|
// p2 = x1 * y1
|
||||||
|
mac3(&mut p.data[..], x1, y1);
|
||||||
|
|
||||||
|
// Not required, but the adds go faster if we drop any unneeded 0s from the end:
|
||||||
|
p = p.normalize();
|
||||||
|
|
||||||
|
add2(&mut acc[b..], &p.data[..]);
|
||||||
|
add2(&mut acc[b * 2..], &p.data[..]);
|
||||||
|
|
||||||
|
// Zero out p before the next multiply:
|
||||||
|
p.data.truncate(0);
|
||||||
|
p.data.extend(repeat(0).take(len));
|
||||||
|
|
||||||
|
// p0 = x0 * y0
|
||||||
|
mac3(&mut p.data[..], x0, y0);
|
||||||
|
p = p.normalize();
|
||||||
|
|
||||||
|
add2(&mut acc[..], &p.data[..]);
|
||||||
|
add2(&mut acc[b..], &p.data[..]);
|
||||||
|
|
||||||
|
// p1 = (x1 - x0) * (y1 - y0)
|
||||||
|
// We do this one last, since it may be negative and acc can't ever be negative:
|
||||||
|
let j0 = sub_sign(x1, x0);
|
||||||
|
let j1 = sub_sign(y1, y0);
|
||||||
|
|
||||||
|
match j0.sign * j1.sign {
|
||||||
|
Plus => {
|
||||||
|
p.data.truncate(0);
|
||||||
|
p.data.extend(repeat(0).take(len));
|
||||||
|
|
||||||
|
mac3(&mut p.data[..], &j0.data.data[..], &j1.data.data[..]);
|
||||||
|
p = p.normalize();
|
||||||
|
|
||||||
|
sub2(&mut acc[b..], &p.data[..]);
|
||||||
|
},
|
||||||
|
Minus => {
|
||||||
|
mac3(&mut acc[b..], &j0.data.data[..], &j1.data.data[..]);
|
||||||
|
},
|
||||||
|
NoSign => (),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
|
||||||
|
let len = x.len() + y.len() + 1;
|
||||||
|
let mut prod: BigUint = BigUint { data: Vec::with_capacity(len) };
|
||||||
|
|
||||||
|
// resize isn't stable yet:
|
||||||
|
//prod.data.resize(len, 0);
|
||||||
|
prod.data.extend(repeat(0).take(len));
|
||||||
|
|
||||||
|
mac3(&mut prod.data[..], x, y);
|
||||||
|
prod.normalize()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
|
||||||
|
type Output = BigUint;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul(self, other: &BigUint) -> BigUint {
|
||||||
|
mul3(&self.data[..], &other.data[..])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
forward_all_binop_to_ref_ref!(impl Div for BigUint, div);
|
forward_all_binop_to_ref_ref!(impl Div for BigUint, div);
|
||||||
|
|
||||||
|
@ -3131,6 +3270,16 @@ mod biguint_tests {
|
||||||
// Switching u and l should fail:
|
// Switching u and l should fail:
|
||||||
let _n: BigUint = rng.gen_biguint_range(&u, &l);
|
let _n: BigUint = rng.gen_biguint_range(&u, &l);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `sub_sign` applied to raw digit slices must give the same result as signed
/// `BigInt` subtraction, in both operand orders.
#[test]
fn test_sub_sign() {
    use super::sub_sign;

    let first = BigInt::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let second = BigInt::from_str_radix("26525285981219105863630848000000", 10).unwrap();

    let first_digits = &first.data.data[..];
    let second_digits = &second.data.data[..];

    assert_eq!(sub_sign(first_digits, second_digits), &first - &second);
    assert_eq!(sub_sign(second_digits, first_digits), &second - &first);
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
Loading…
Reference in New Issue