1 use super::addition::{__add2, add2};
2 use super::subtraction::sub2;
3 #[cfg(not(u64_digit))]
4 use super::u32_from_u128;
5 use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
6 
7 use crate::big_digit::{self, BigDigit, DoubleBigDigit};
8 use crate::Sign::{self, Minus, NoSign, Plus};
9 use crate::{BigInt, UsizePromotion};
10 
11 use core::cmp::Ordering;
12 use core::iter::Product;
13 use core::ops::{Mul, MulAssign};
14 use num_traits::{CheckedMul, FromPrimitive, One, Zero};
15 
16 #[inline]
mac_with_carry( a: BigDigit, b: BigDigit, c: BigDigit, acc: &mut DoubleBigDigit, ) -> BigDigit17 pub(super) fn mac_with_carry(
18     a: BigDigit,
19     b: BigDigit,
20     c: BigDigit,
21     acc: &mut DoubleBigDigit,
22 ) -> BigDigit {
23     *acc += DoubleBigDigit::from(a);
24     *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
25     let lo = *acc as BigDigit;
26     *acc >>= big_digit::BITS;
27     lo
28 }
29 
30 #[inline]
mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit31 fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
32     *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
33     let lo = *acc as BigDigit;
34     *acc >>= big_digit::BITS;
35     lo
36 }
37 
38 /// Three argument multiply accumulate:
39 /// acc += b * c
mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit)40 fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
41     if c == 0 {
42         return;
43     }
44 
45     let mut carry = 0;
46     let (a_lo, a_hi) = acc.split_at_mut(b.len());
47 
48     for (a, &b) in a_lo.iter_mut().zip(b) {
49         *a = mac_with_carry(*a, b, c, &mut carry);
50     }
51 
52     let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
53 
54     let final_carry = if carry_hi == 0 {
55         __add2(a_hi, &[carry_lo])
56     } else {
57         __add2(a_hi, &[carry_hi, carry_lo])
58     };
59     assert_eq!(final_carry, 0, "carry overflow during multiplication!");
60 }
61 
bigint_from_slice(slice: &[BigDigit]) -> BigInt62 fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
63     BigInt::from(biguint_from_vec(slice.to_vec()))
64 }
65 
/// Three argument multiply accumulate:
/// acc += b * c
///
/// Digits are stored least-significant first. `acc` must be long enough to
/// absorb the sum. The callers in this file allocate `b.len() + c.len() + 1`
/// digits for it (see `mul3` and the Karatsuba temporary `p` below).
///
/// The algorithm (long multiplication, Karatsuba, or Toom-3) is chosen from the
/// length of the shorter operand. The Karatsuba branch recurses into `mac3` for
/// its sub-products.
#[allow(clippy::many_single_char_names)]
fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
    // Least-significant zeros have no effect on the output.
    //
    // Skipping `nz` low zero digits of an operand divides it by base^nz, so
    // `acc` is advanced by the same `nz` digits to keep the product aligned.
    // If an operand is entirely zero, the product is zero and nothing is added.
    if let Some(&0) = b.first() {
        if let Some(nz) = b.iter().position(|&d| d != 0) {
            b = &b[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }
    if let Some(&0) = c.first() {
        if let Some(nz) = c.iter().position(|&d| d != 0) {
            c = &c[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }

    // Rebind without `mut`: the slice no longer needs to be re-pointed.
    let acc = acc;
    // `x` is the shorter operand (on a tie, `x` is `c`) and `y` the longer one.
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use three algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication:
        // add y * x[i], shifted up by i digits, for every digit of x.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() <= 256 {
        // Karatsuba multiplication:
        //
        // The idea is that we break x and y up into two smaller numbers that each have about half
        // as many digits, like so (note that multiplying by b is just a shift):
        //
        // x = x0 + x1 * b
        // y = y0 + y1 * b
        //
        // With some algebra, we can compute x * y with three smaller products, where the inputs to
        // each of the smaller products have only about half as many digits as x and y:
        //
        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
        //
        // x * y = x0 * y0
        //       + x0 * y1 * b
        //       + x1 * y0 * b
        //       + x1 * y1 * b^2
        //
        // Let p0 = x0 * y0 and p2 = x1 * y1:
        //
        // x * y = p0
        //       + (x0 * y1 + x1 * y0) * b
        //       + p2 * b^2
        //
        // The real trick is that middle term:
        //
        //         x0 * y1 + x1 * y0
        //
        //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
        //
        //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
        //
        // Now we factor the negated terms:
        //
        //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
        //
        //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
        //
        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
        //
        // x * y = p0
        //       + (p0 + p2 - p1) * b
        //       + p2 * b^2
        //
        // Where the three intermediate products are:
        //
        // p0 = x0 * y0
        // p1 = (x1 - x0) * (y1 - y0)
        // p2 = x1 * y1
        //
        // In doing the computation, we take great care to avoid unnecessary temporary variables
        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
        // bit so we can use the same temporary variable for all the intermediate products:
        //
        // x * y = p2 * b^2 + p2 * b
        //       + p0 * b + p0
        //       - p1 * b
        //
        // The other trick we use is instead of doing explicit shifts, we slice acc at the
        // appropriate offset when doing the add.

        // When x is smaller than y, it's significantly faster to pick b such that x is split in
        // half, not y:
        // (this `b` is a digit count and shadows the operand slice `b` from here on)
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        // We reuse the same BigUint for all the intermediate multiplies and have to size p
        // appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };

        // p2 = x1 * y1
        mac3(&mut p.data, x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data);
        add2(&mut acc[b * 2..], &p.data);

        // Zero out p before the next multiply:
        // (mac3 accumulates into its destination, so p must start at zero.)
        p.data.truncate(0);
        p.data.resize(len, 0);

        // p0 = x0 * y0
        mac3(&mut p.data, x0, y0);
        p.normalize();

        // acc += p0 + p0 * b
        add2(acc, &p.data);
        add2(&mut acc[b..], &p.data);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        match j0_sign * j1_sign {
            // p1 > 0: compute |p1| into the scratch buffer, then acc -= p1 * b.
            Plus => {
                p.data.truncate(0);
                p.data.resize(len, 0);

                mac3(&mut p.data, &j0.data, &j1.data);
                p.normalize();

                sub2(&mut acc[b..], &p.data);
            }
            // p1 < 0: subtracting p1 * b means adding |p1| * b, which mac3 can
            // accumulate straight into acc without a temporary.
            Minus => {
                mac3(&mut acc[b..], &j0.data, &j1.data);
            }
            // p1 == 0: nothing to subtract.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        //
        // Each part holds at most `i` digits. The split width is derived from the
        // longer operand `y`, so the shorter `x` may have empty upper parts.
        let i = y.len() / 3 + 1;

        let x0_len = Ord::min(x.len(), i);
        let x1_len = Ord::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = Ord::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing an order two polynomial.
        // t is chosen to be a power of the digit base, t = 2^(BITS * i), so
        // multiplying by t^j is a shift by i * j whole digits. This lets us use
        // faster shifts (slice offsets into acc) in place of multiplications.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following order-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // It is arbitrary as to what points we evaluate w at but we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2)
        // ((x2 - x1 + x0) + x2) * 2 - x0 = 4*x2 - 2*x1 + x0, and likewise for y.
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear equations
        // (the columns are the unknown coefficients w4, w3, w2, w1, w0):
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solved equation (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations. Every division in it is exact (r3 - r1 is a
        // multiple of 3, and the halved quantities are even), so the rounding
        // behavior of `/` and `>>` does not matter. Afterwards comp1, comp2 and
        // comp3 hold w1, w2 and w3.
        let mut comp3: BigInt = (r3 - &r1) / 3u32;
        let mut comp1: BigInt = (r1 - &r2) >> 1;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        //
        //     let bits = u64::from(big_digit::BITS) * i as u64;
        //     let result = r0
        //         + (comp1 << bits)
        //         + (comp2 << (2 * bits))
        //         + (comp3 << (3 * bits))
        //         + (r4 << (4 * bits));
        //     let result_pos = result.to_biguint().unwrap();
        //     add2(&mut acc[..], &result_pos.data);
        //
        // But with less intermediate copying:
        // coefficient j is added to or subtracted from acc at digit offset i * j.
        // Higher-order coefficients are applied first (`.rev()`), presumably so
        // that acc never needs to go negative while a negative middle coefficient
        // is being subtracted. NOTE(review): confirm this ordering invariant.
        // Zero coefficients are skipped, so their offset is never sliced.
        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
            match result.sign() {
                Plus => add2(&mut acc[i * j..], result.digits()),
                Minus => sub2(&mut acc[i * j..], result.digits()),
                NoSign => {}
            }
        }
    }
}
351 
mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint352 fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
353     let len = x.len() + y.len() + 1;
354     let mut prod = BigUint { data: vec![0; len] };
355 
356     mac3(&mut prod.data, x, y);
357     prod.normalized()
358 }
359 
scalar_mul(a: &mut BigUint, b: BigDigit)360 fn scalar_mul(a: &mut BigUint, b: BigDigit) {
361     match b {
362         0 => a.set_zero(),
363         1 => {}
364         _ => {
365             if b.is_power_of_two() {
366                 *a <<= b.trailing_zeros();
367             } else {
368                 let mut carry = 0;
369                 for a in a.data.iter_mut() {
370                     *a = mul_with_carry(*a, b, &mut carry);
371                 }
372                 if carry != 0 {
373                     a.data.push(carry as BigDigit);
374                 }
375             }
376         }
377     }
378 }
379 
sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint)380 fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
381     // Normalize:
382     if let Some(&0) = a.last() {
383         a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
384     }
385     if let Some(&0) = b.last() {
386         b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
387     }
388 
389     match cmp_slice(a, b) {
390         Ordering::Greater => {
391             let mut a = a.to_vec();
392             sub2(&mut a, b);
393             (Plus, biguint_from_vec(a))
394         }
395         Ordering::Less => {
396             let mut b = b.to_vec();
397             sub2(&mut b, a);
398             (Minus, biguint_from_vec(b))
399         }
400         Ordering::Equal => (NoSign, Zero::zero()),
401     }
402 }
403 
/// Generates `Mul` impls producing a `BigUint`, one per listed
/// `impl Mul<$Other> for $Self;` line.
///
/// Dispatch order matters:
/// - an empty digit vector (zero) on either side short-circuits to zero;
/// - a single-digit `other` is handled as `self * digit`;
/// - otherwise a single-digit `self` is handled as `other * digit`
///   (so two single-digit operands take the `self * digit` arm);
/// - everything else goes through `mul3`.
macro_rules! impl_mul {
    ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
        impl Mul<$Other> for $Self {
            type Output = BigUint;

            #[inline]
            fn mul(self, other: $Other) -> BigUint {
                match (&*self.data, &*other.data) {
                    // multiply by zero
                    (&[], _) | (_, &[]) => BigUint::zero(),
                    // multiply by a scalar
                    (_, &[digit]) => self * digit,
                    (&[digit], _) => other * digit,
                    // full multiplication
                    (x, y) => mul3(x, y),
                }
            }
        }
    )*}
}
// All four owned/borrowed operand combinations.
impl_mul! {
    impl Mul<BigUint> for BigUint;
    impl Mul<BigUint> for &BigUint;
    impl Mul<&BigUint> for BigUint;
    impl Mul<&BigUint> for &BigUint;
}
430 
/// Generates `MulAssign` impls for `BigUint`, one per listed
/// `impl MulAssign<$Other> for BigUint;` line.
///
/// Dispatch order:
/// - if `self` is zero (empty digits) it stays zero, untouched;
/// - if `other` is zero, `self` is cleared;
/// - a single-digit `other` takes the scalar path in place;
/// - a single-digit `self` is replaced by `other * digit`;
/// - otherwise `self` is replaced by `mul3` of both digit slices.
macro_rules! impl_mul_assign {
    ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
        impl MulAssign<$Other> for BigUint {
            #[inline]
            fn mul_assign(&mut self, other: $Other) {
                match (&*self.data, &*other.data) {
                    // multiply by zero
                    (&[], _) => {},
                    (_, &[]) => self.set_zero(),
                    // multiply by a scalar
                    (_, &[digit]) => *self *= digit,
                    (&[digit], _) => *self = other * digit,
                    // full multiplication
                    (x, y) => *self = mul3(x, y),
                }
            }
        }
    )*}
}
// Owned and borrowed right-hand sides.
impl_mul_assign! {
    impl MulAssign<BigUint> for BigUint;
    impl MulAssign<&BigUint> for BigUint;
}
454 
// Scalar multiplication plumbing. These macros are defined elsewhere in the crate.
// NOTE(review): presumably `promote_unsigned_scalars*` route the remaining unsigned
// primitive types (e.g. u8/u16/usize) to the `u32`/`u64` impls below by widening.
// Presumably `forward_all_scalar_binop_to_val_val_commutative` derives the
// by-reference and scalar-on-the-left `Mul` variants from the by-value
// `impl Mul<uN> for BigUint`. Confirm against the macro definitions.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
460 
461 impl Mul<u32> for BigUint {
462     type Output = BigUint;
463 
464     #[inline]
mul(mut self, other: u32) -> BigUint465     fn mul(mut self, other: u32) -> BigUint {
466         self *= other;
467         self
468     }
469 }
470 impl MulAssign<u32> for BigUint {
471     #[inline]
mul_assign(&mut self, other: u32)472     fn mul_assign(&mut self, other: u32) {
473         scalar_mul(self, other as BigDigit);
474     }
475 }
476 
477 impl Mul<u64> for BigUint {
478     type Output = BigUint;
479 
480     #[inline]
mul(mut self, other: u64) -> BigUint481     fn mul(mut self, other: u64) -> BigUint {
482         self *= other;
483         self
484     }
485 }
486 impl MulAssign<u64> for BigUint {
487     #[cfg(not(u64_digit))]
488     #[inline]
mul_assign(&mut self, other: u64)489     fn mul_assign(&mut self, other: u64) {
490         if let Some(other) = BigDigit::from_u64(other) {
491             scalar_mul(self, other);
492         } else {
493             let (hi, lo) = big_digit::from_doublebigdigit(other);
494             *self = mul3(&self.data, &[lo, hi]);
495         }
496     }
497 
498     #[cfg(u64_digit)]
499     #[inline]
mul_assign(&mut self, other: u64)500     fn mul_assign(&mut self, other: u64) {
501         scalar_mul(self, other);
502     }
503 }
504 
505 impl Mul<u128> for BigUint {
506     type Output = BigUint;
507 
508     #[inline]
mul(mut self, other: u128) -> BigUint509     fn mul(mut self, other: u128) -> BigUint {
510         self *= other;
511         self
512     }
513 }
514 
515 impl MulAssign<u128> for BigUint {
516     #[cfg(not(u64_digit))]
517     #[inline]
mul_assign(&mut self, other: u128)518     fn mul_assign(&mut self, other: u128) {
519         if let Some(other) = BigDigit::from_u128(other) {
520             scalar_mul(self, other);
521         } else {
522             *self = match u32_from_u128(other) {
523                 (0, 0, c, d) => mul3(&self.data, &[d, c]),
524                 (0, b, c, d) => mul3(&self.data, &[d, c, b]),
525                 (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
526             };
527         }
528     }
529 
530     #[cfg(u64_digit)]
531     #[inline]
mul_assign(&mut self, other: u128)532     fn mul_assign(&mut self, other: u128) {
533         if let Some(other) = BigDigit::from_u128(other) {
534             scalar_mul(self, other);
535         } else {
536             let (hi, lo) = big_digit::from_doublebigdigit(other);
537             *self = mul3(&self.data, &[lo, hi]);
538         }
539     }
540 }
541 
542 impl CheckedMul for BigUint {
543     #[inline]
checked_mul(&self, v: &BigUint) -> Option<BigUint>544     fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
545         Some(self.mul(v))
546     }
547 }
548 
// NOTE(review): this macro is defined elsewhere; presumably it implements
// `core::iter::Product` (imported above) for `BigUint` on top of its `Mul` impls.
impl_product_iter_type!(BigUint);
550 
/// Checks `sub_sign` against `BigInt` subtraction for the greater, lesser and
/// equal cases, plus inputs with high-order zero digits and empty inputs.
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, val) = sub_sign(a, b);
        BigInt::from_biguint(sign, val)
    }

    let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let a_i = BigInt::from(a.clone());
    let b_i = BigInt::from(b.clone());

    // a > b and a < b: the Plus and Minus branches.
    assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);

    // Equal magnitudes: the NoSign branch must report a zero difference.
    assert_eq!(sub_sign(&a.data, &a.data).0, NoSign);
    assert_eq!(sub_sign_i(&a.data, &a.data), &a_i - &a_i);

    // High-order zero digits must not affect the comparison or the result.
    let mut a_padded = a.data.clone();
    a_padded.extend_from_slice(&[0, 0]);
    assert_eq!(sub_sign_i(&a_padded, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a_padded), &b_i - &a_i);
    assert_eq!(sub_sign(&a_padded, &a.data).0, NoSign);

    // Empty (zero) operands.
    assert_eq!(sub_sign(&[], &[]).0, NoSign);
    assert_eq!(sub_sign(&[0, 0], &[]).0, NoSign);
    assert_eq!(sub_sign_i(&a.data, &[]), a_i.clone());
    assert_eq!(sub_sign_i(&[], &a.data), -a_i.clone());
}
569