587 lines
17 KiB
Rust
587 lines
17 KiB
Rust
use std::borrow::Cow;
|
|
use std::cmp;
|
|
use std::cmp::Ordering::{self, Less, Greater, Equal};
|
|
use std::iter::repeat;
|
|
use std::mem;
|
|
use traits;
|
|
use traits::{Zero, One};
|
|
|
|
use biguint::BigUint;
|
|
|
|
use bigint::Sign;
|
|
use bigint::Sign::{Minus, NoSign, Plus};
|
|
|
|
#[allow(non_snake_case)]
pub mod big_digit {
    /// The unit a `BigUint` is composed of: one 32-bit limb.
    pub type BigDigit = u32;

    /// The wider type used for intermediate computations. It is twice as
    /// wide as `BigDigit`, so the sum or product of two digits fits in it.
    pub type DoubleBigDigit = u64;

    pub const ZERO_BIG_DIGIT: BigDigit = 0;

    // The constants below depend on the width of `BigDigit`.
    pub const BITS: usize = 32;

    /// The numeric base of a single digit: `2^BITS`.
    pub const BASE: DoubleBigDigit = 1 << BITS;

    /// Mask selecting the low `BITS` bits of a `DoubleBigDigit`.
    const LO_MASK: DoubleBigDigit = BASE - 1;

    /// Split one `DoubleBigDigit` into two `BigDigit`s, returned as
    /// `(high, low)`.
    #[inline]
    pub fn from_doublebigdigit(n: DoubleBigDigit) -> (BigDigit, BigDigit) {
        let high = (n >> BITS) as BigDigit;
        let low = (n & LO_MASK) as BigDigit;
        (high, low)
    }

    /// Join two `BigDigit`s into one `DoubleBigDigit`; `hi` becomes the upper
    /// `BITS` bits and `lo` the lower `BITS` bits.
    #[inline]
    pub fn to_doublebigdigit(hi: BigDigit, lo: BigDigit) -> DoubleBigDigit {
        let upper = (hi as DoubleBigDigit) << BITS;
        upper | (lo as DoubleBigDigit)
    }
}
|
|
|
|
use big_digit::{BigDigit, DoubleBigDigit};
|
|
|
|
// Generic functions for add/subtract/multiply with carry/borrow:
|
|
|
|
// Add with carry:
|
|
#[inline]
|
|
fn adc(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit {
|
|
let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) + (b as DoubleBigDigit) +
|
|
(*carry as DoubleBigDigit));
|
|
|
|
*carry = hi;
|
|
lo
|
|
}
|
|
|
|
// Subtract with borrow:
|
|
#[inline]
|
|
fn sbb(a: BigDigit, b: BigDigit, borrow: &mut BigDigit) -> BigDigit {
|
|
let (hi, lo) = big_digit::from_doublebigdigit(big_digit::BASE + (a as DoubleBigDigit) -
|
|
(b as DoubleBigDigit) -
|
|
(*borrow as DoubleBigDigit));
|
|
// hi * (base) + lo == 1*(base) + ai - bi - borrow
|
|
// => ai - bi - borrow < 0 <=> hi == 0
|
|
*borrow = (hi == 0) as BigDigit;
|
|
lo
|
|
}
|
|
|
|
#[inline]
|
|
pub fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigit) -> BigDigit {
|
|
let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) +
|
|
(b as DoubleBigDigit) * (c as DoubleBigDigit) +
|
|
(*carry as DoubleBigDigit));
|
|
*carry = hi;
|
|
lo
|
|
}
|
|
|
|
/// Divide a two digit numerator by a one digit divisor, returns quotient and remainder:
|
|
///
|
|
/// Note: the caller must ensure that both the quotient and remainder will fit into a single digit.
|
|
/// This is _not_ true for an arbitrary numerator/denominator.
|
|
///
|
|
/// (This function also matches what the x86 divide instruction does).
|
|
#[inline]
|
|
fn div_wide(hi: BigDigit, lo: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
|
|
debug_assert!(hi < divisor);
|
|
|
|
let lhs = big_digit::to_doublebigdigit(hi, lo);
|
|
let rhs = divisor as DoubleBigDigit;
|
|
((lhs / rhs) as BigDigit, (lhs % rhs) as BigDigit)
|
|
}
|
|
|
|
pub fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) {
|
|
let mut rem = 0;
|
|
|
|
for d in a.data.iter_mut().rev() {
|
|
let (q, r) = div_wide(rem, *d, b);
|
|
*d = q;
|
|
rem = r;
|
|
}
|
|
|
|
(a.normalize(), rem)
|
|
}
|
|
|
|
// Only for the Add impl:
|
|
#[must_use]
|
|
#[inline]
|
|
pub fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
|
|
debug_assert!(a.len() >= b.len());
|
|
|
|
let mut carry = 0;
|
|
let (a_lo, a_hi) = a.split_at_mut(b.len());
|
|
|
|
for (a, b) in a_lo.iter_mut().zip(b) {
|
|
*a = adc(*a, *b, &mut carry);
|
|
}
|
|
|
|
if carry != 0 {
|
|
for a in a_hi {
|
|
*a = adc(*a, 0, &mut carry);
|
|
if carry == 0 { break }
|
|
}
|
|
}
|
|
|
|
carry
|
|
}
|
|
|
|
/// /Two argument addition of raw slices:
|
|
/// a += b
|
|
///
|
|
/// The caller _must_ ensure that a is big enough to store the result - typically this means
|
|
/// resizing a to max(a.len(), b.len()) + 1, to fit a possible carry.
|
|
pub fn add2(a: &mut [BigDigit], b: &[BigDigit]) {
|
|
let carry = __add2(a, b);
|
|
|
|
debug_assert!(carry == 0);
|
|
}
|
|
|
|
pub fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
|
|
let mut borrow = 0;
|
|
|
|
let len = cmp::min(a.len(), b.len());
|
|
let (a_lo, a_hi) = a.split_at_mut(len);
|
|
let (b_lo, b_hi) = b.split_at(len);
|
|
|
|
for (a, b) in a_lo.iter_mut().zip(b_lo) {
|
|
*a = sbb(*a, *b, &mut borrow);
|
|
}
|
|
|
|
if borrow != 0 {
|
|
for a in a_hi {
|
|
*a = sbb(*a, 0, &mut borrow);
|
|
if borrow == 0 { break }
|
|
}
|
|
}
|
|
|
|
// note: we're _required_ to fail on underflow
|
|
assert!(borrow == 0 && b_hi.iter().all(|x| *x == 0),
|
|
"Cannot subtract b from a because b is larger than a.");
|
|
}
|
|
|
|
pub fn sub2rev(a: &[BigDigit], b: &mut [BigDigit]) {
|
|
debug_assert!(b.len() >= a.len());
|
|
|
|
let mut borrow = 0;
|
|
|
|
let len = cmp::min(a.len(), b.len());
|
|
let (a_lo, a_hi) = a.split_at(len);
|
|
let (b_lo, b_hi) = b.split_at_mut(len);
|
|
|
|
for (a, b) in a_lo.iter().zip(b_lo) {
|
|
*b = sbb(*a, *b, &mut borrow);
|
|
}
|
|
|
|
assert!(a_hi.is_empty());
|
|
|
|
// note: we're _required_ to fail on underflow
|
|
assert!(borrow == 0 && b_hi.iter().all(|x| *x == 0),
|
|
"Cannot subtract b from a because b is larger than a.");
|
|
}
|
|
|
|
pub fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> (Sign, BigUint) {
|
|
// Normalize:
|
|
let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
|
|
let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
|
|
|
|
match cmp_slice(a, b) {
|
|
Greater => {
|
|
let mut a = a.to_vec();
|
|
sub2(&mut a, b);
|
|
(Plus, BigUint::new(a))
|
|
}
|
|
Less => {
|
|
let mut b = b.to_vec();
|
|
sub2(&mut b, a);
|
|
(Minus, BigUint::new(b))
|
|
}
|
|
_ => (NoSign, Zero::zero()),
|
|
}
|
|
}
|
|
|
|
/// Three argument multiply accumulate:
|
|
/// acc += b * c
|
|
fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
|
|
if c == 0 {
|
|
return;
|
|
}
|
|
|
|
let mut b_iter = b.iter();
|
|
let mut carry = 0;
|
|
|
|
for ai in acc.iter_mut() {
|
|
if let Some(bi) = b_iter.next() {
|
|
*ai = mac_with_carry(*ai, *bi, c, &mut carry);
|
|
} else if carry != 0 {
|
|
*ai = mac_with_carry(*ai, 0, c, &mut carry);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
assert!(carry == 0);
|
|
}
|
|
|
|
/// Three argument multiply accumulate:
|
|
/// acc += b * c
|
|
fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
|
|
let (x, y) = if b.len() < c.len() {
|
|
(b, c)
|
|
} else {
|
|
(c, b)
|
|
};
|
|
|
|
// Karatsuba multiplication is slower than long multiplication for small x and y:
|
|
//
|
|
if x.len() <= 4 {
|
|
for (i, xi) in x.iter().enumerate() {
|
|
mac_digit(&mut acc[i..], y, *xi);
|
|
}
|
|
} else {
|
|
/*
|
|
* Karatsuba multiplication:
|
|
*
|
|
* The idea is that we break x and y up into two smaller numbers that each have about half
|
|
* as many digits, like so (note that multiplying by b is just a shift):
|
|
*
|
|
* x = x0 + x1 * b
|
|
* y = y0 + y1 * b
|
|
*
|
|
* With some algebra, we can compute x * y with three smaller products, where the inputs to
|
|
* each of the smaller products have only about half as many digits as x and y:
|
|
*
|
|
* x * y = (x0 + x1 * b) * (y0 + y1 * b)
|
|
*
|
|
* x * y = x0 * y0
|
|
* + x0 * y1 * b
|
|
* + x1 * y0 * b
|
|
* + x1 * y1 * b^2
|
|
*
|
|
* Let p0 = x0 * y0 and p2 = x1 * y1:
|
|
*
|
|
* x * y = p0
|
|
* + (x0 * y1 + x1 * y0) * b
|
|
* + p2 * b^2
|
|
*
|
|
* The real trick is that middle term:
|
|
*
|
|
* x0 * y1 + x1 * y0
|
|
*
|
|
* = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
|
|
*
|
|
* = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
|
|
*
|
|
* Now we complete the square:
|
|
*
|
|
* = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
|
|
*
|
|
* = -((x1 - x0) * (y1 - y0)) + p0 + p2
|
|
*
|
|
* Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
|
|
*
|
|
* x * y = p0
|
|
* + (p0 + p2 - p1) * b
|
|
* + p2 * b^2
|
|
*
|
|
* Where the three intermediate products are:
|
|
*
|
|
* p0 = x0 * y0
|
|
* p1 = (x1 - x0) * (y1 - y0)
|
|
* p2 = x1 * y1
|
|
*
|
|
* In doing the computation, we take great care to avoid unnecessary temporary variables
|
|
* (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
|
|
* bit so we can use the same temporary variable for all the intermediate products:
|
|
*
|
|
* x * y = p2 * b^2 + p2 * b
|
|
* + p0 * b + p0
|
|
* - p1 * b
|
|
*
|
|
* The other trick we use is instead of doing explicit shifts, we slice acc at the
|
|
* appropriate offset when doing the add.
|
|
*/
|
|
|
|
/*
|
|
* When x is smaller than y, it's significantly faster to pick b such that x is split in
|
|
* half, not y:
|
|
*/
|
|
let b = x.len() / 2;
|
|
let (x0, x1) = x.split_at(b);
|
|
let (y0, y1) = y.split_at(b);
|
|
|
|
/*
|
|
* We reuse the same BigUint for all the intermediate multiplies and have to size p
|
|
* appropriately here: x1.len() >= x0.len and y1.len() >= y0.len():
|
|
*/
|
|
let len = x1.len() + y1.len() + 1;
|
|
let mut p = BigUint { data: vec![0; len] };
|
|
|
|
// p2 = x1 * y1
|
|
mac3(&mut p.data[..], x1, y1);
|
|
|
|
// Not required, but the adds go faster if we drop any unneeded 0s from the end:
|
|
p = p.normalize();
|
|
|
|
add2(&mut acc[b..], &p.data[..]);
|
|
add2(&mut acc[b * 2..], &p.data[..]);
|
|
|
|
// Zero out p before the next multiply:
|
|
p.data.truncate(0);
|
|
p.data.extend(repeat(0).take(len));
|
|
|
|
// p0 = x0 * y0
|
|
mac3(&mut p.data[..], x0, y0);
|
|
p = p.normalize();
|
|
|
|
add2(&mut acc[..], &p.data[..]);
|
|
add2(&mut acc[b..], &p.data[..]);
|
|
|
|
// p1 = (x1 - x0) * (y1 - y0)
|
|
// We do this one last, since it may be negative and acc can't ever be negative:
|
|
let (j0_sign, j0) = sub_sign(x1, x0);
|
|
let (j1_sign, j1) = sub_sign(y1, y0);
|
|
|
|
match j0_sign * j1_sign {
|
|
Plus => {
|
|
p.data.truncate(0);
|
|
p.data.extend(repeat(0).take(len));
|
|
|
|
mac3(&mut p.data[..], &j0.data[..], &j1.data[..]);
|
|
p = p.normalize();
|
|
|
|
sub2(&mut acc[b..], &p.data[..]);
|
|
},
|
|
Minus => {
|
|
mac3(&mut acc[b..], &j0.data[..], &j1.data[..]);
|
|
},
|
|
NoSign => (),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
|
|
let len = x.len() + y.len() + 1;
|
|
let mut prod = BigUint { data: vec![0; len] };
|
|
|
|
mac3(&mut prod.data[..], x, y);
|
|
prod.normalize()
|
|
}
|
|
|
|
pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) {
|
|
if d.is_zero() {
|
|
panic!()
|
|
}
|
|
if u.is_zero() {
|
|
return (Zero::zero(), Zero::zero());
|
|
}
|
|
if *d == One::one() {
|
|
return (u.clone(), Zero::zero());
|
|
}
|
|
|
|
// Required or the q_len calculation below can underflow:
|
|
match u.cmp(d) {
|
|
Less => return (Zero::zero(), u.clone()),
|
|
Equal => return (One::one(), Zero::zero()),
|
|
Greater => {} // Do nothing
|
|
}
|
|
|
|
// This algorithm is from Knuth, TAOCP vol 2 section 4.3, algorithm D:
|
|
//
|
|
// First, normalize the arguments so the highest bit in the highest digit of the divisor is
|
|
// set: the main loop uses the highest digit of the divisor for generating guesses, so we
|
|
// want it to be the largest number we can efficiently divide by.
|
|
//
|
|
let shift = d.data.last().unwrap().leading_zeros() as usize;
|
|
let mut a = u << shift;
|
|
let b = d << shift;
|
|
|
|
// The algorithm works by incrementally calculating "guesses", q0, for part of the
|
|
// remainder. Once we have any number q0 such that q0 * b <= a, we can set
|
|
//
|
|
// q += q0
|
|
// a -= q0 * b
|
|
//
|
|
// and then iterate until a < b. Then, (q, a) will be our desired quotient and remainder.
|
|
//
|
|
// q0, our guess, is calculated by dividing the last few digits of a by the last digit of b
|
|
// - this should give us a guess that is "close" to the actual quotient, but is possibly
|
|
// greater than the actual quotient. If q0 * b > a, we simply use iterated subtraction
|
|
// until we have a guess such that q0 & b <= a.
|
|
//
|
|
|
|
let bn = *b.data.last().unwrap();
|
|
let q_len = a.data.len() - b.data.len() + 1;
|
|
let mut q = BigUint { data: vec![0; q_len] };
|
|
|
|
// We reuse the same temporary to avoid hitting the allocator in our inner loop - this is
|
|
// sized to hold a0 (in the common case; if a particular digit of the quotient is zero a0
|
|
// can be bigger).
|
|
//
|
|
let mut tmp = BigUint { data: Vec::with_capacity(2) };
|
|
|
|
for j in (0..q_len).rev() {
|
|
/*
|
|
* When calculating our next guess q0, we don't need to consider the digits below j
|
|
* + b.data.len() - 1: we're guessing digit j of the quotient (i.e. q0 << j) from
|
|
* digit bn of the divisor (i.e. bn << (b.data.len() - 1) - so the product of those
|
|
* two numbers will be zero in all digits up to (j + b.data.len() - 1).
|
|
*/
|
|
let offset = j + b.data.len() - 1;
|
|
if offset >= a.data.len() {
|
|
continue;
|
|
}
|
|
|
|
/* just avoiding a heap allocation: */
|
|
let mut a0 = tmp;
|
|
a0.data.truncate(0);
|
|
a0.data.extend(a.data[offset..].iter().cloned());
|
|
|
|
/*
|
|
* q0 << j * big_digit::BITS is our actual quotient estimate - we do the shifts
|
|
* implicitly at the end, when adding and subtracting to a and q. Not only do we
|
|
* save the cost of the shifts, the rest of the arithmetic gets to work with
|
|
* smaller numbers.
|
|
*/
|
|
let (mut q0, _) = div_rem_digit(a0, bn);
|
|
let mut prod = &b * &q0;
|
|
|
|
while cmp_slice(&prod.data[..], &a.data[j..]) == Greater {
|
|
let one: BigUint = One::one();
|
|
q0 = q0 - one;
|
|
prod = prod - &b;
|
|
}
|
|
|
|
add2(&mut q.data[j..], &q0.data[..]);
|
|
sub2(&mut a.data[j..], &prod.data[..]);
|
|
a = a.normalize();
|
|
|
|
tmp = q0;
|
|
}
|
|
|
|
debug_assert!(a < b);
|
|
|
|
(q.normalize(), a >> shift)
|
|
}
|
|
|
|
/// Find last set bit
|
|
/// fls(0) == 0, fls(u32::MAX) == 32
|
|
pub fn fls<T: traits::PrimInt>(v: T) -> usize {
|
|
mem::size_of::<T>() * 8 - v.leading_zeros() as usize
|
|
}
|
|
|
|
pub fn ilog2<T: traits::PrimInt>(v: T) -> usize {
|
|
fls(v) - 1
|
|
}
|
|
|
|
#[inline]
|
|
pub fn biguint_shl(n: Cow<BigUint>, bits: usize) -> BigUint {
|
|
let n_unit = bits / big_digit::BITS;
|
|
let mut data = match n_unit {
|
|
0 => n.into_owned().data,
|
|
_ => {
|
|
let len = n_unit + n.data.len() + 1;
|
|
let mut data = Vec::with_capacity(len);
|
|
data.extend(repeat(0).take(n_unit));
|
|
data.extend(n.data.iter().cloned());
|
|
data
|
|
}
|
|
};
|
|
|
|
let n_bits = bits % big_digit::BITS;
|
|
if n_bits > 0 {
|
|
let mut carry = 0;
|
|
for elem in data[n_unit..].iter_mut() {
|
|
let new_carry = *elem >> (big_digit::BITS - n_bits);
|
|
*elem = (*elem << n_bits) | carry;
|
|
carry = new_carry;
|
|
}
|
|
if carry != 0 {
|
|
data.push(carry);
|
|
}
|
|
}
|
|
|
|
BigUint::new(data)
|
|
}
|
|
|
|
#[inline]
|
|
pub fn biguint_shr(n: Cow<BigUint>, bits: usize) -> BigUint {
|
|
let n_unit = bits / big_digit::BITS;
|
|
if n_unit >= n.data.len() {
|
|
return Zero::zero();
|
|
}
|
|
let mut data = match n_unit {
|
|
0 => n.into_owned().data,
|
|
_ => n.data[n_unit..].to_vec(),
|
|
};
|
|
|
|
let n_bits = bits % big_digit::BITS;
|
|
if n_bits > 0 {
|
|
let mut borrow = 0;
|
|
for elem in data.iter_mut().rev() {
|
|
let new_borrow = *elem << (big_digit::BITS - n_bits);
|
|
*elem = (*elem >> n_bits) | borrow;
|
|
borrow = new_borrow;
|
|
}
|
|
}
|
|
|
|
BigUint::new(data)
|
|
}
|
|
|
|
/// Compares two digit slices as numbers stored least significant digit first.
///
/// Neither slice may end in a zero digit (checked in debug builds), so a
/// longer slice is always the larger number; slices of equal length are
/// compared digit by digit starting from the most significant end.
pub fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
    debug_assert!(a.last() != Some(&0));
    debug_assert!(b.last() != Some(&0));

    match a.len().cmp(&b.len()) {
        Equal => {}
        by_length => return by_length,
    }

    for (ai, bi) in a.iter().rev().zip(b.iter().rev()) {
        match ai.cmp(bi) {
            Equal => continue,
            by_digit => return by_digit,
        }
    }

    Equal
}
|
|
|
|
#[cfg(test)]
mod algorithm_tests {
    use {BigDigit, BigUint, BigInt};
    use Sign::Plus;
    use traits::Num;

    /// `sub_sign` must agree with `BigInt` subtraction in both argument orders.
    #[test]
    fn test_sub_sign() {
        use super::sub_sign;

        // Wraps the (sign, magnitude) pair produced by `sub_sign` in a `BigInt`
        // so it can be compared against ordinary `BigInt` subtraction.
        fn signed_difference(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
            let (sign, magnitude) = sub_sign(a, b);
            BigInt::from_biguint(sign, magnitude)
        }

        let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
        let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();

        let a_i = BigInt::from_biguint(Plus, a.clone());
        let b_i = BigInt::from_biguint(Plus, b.clone());

        let forward = signed_difference(&a.data[..], &b.data[..]);
        let backward = signed_difference(&b.data[..], &a.data[..]);

        assert_eq!(forward, &a_i - &b_i);
        assert_eq!(backward, &b_i - &a_i);
    }
}
|