1 use super::addition::{__add2, add2};
2 use super::subtraction::sub2;
3 #[cfg(not(u64_digit))]
4 use super::u32_from_u128;
5 use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
6
7 use crate::big_digit::{self, BigDigit, DoubleBigDigit};
8 use crate::Sign::{self, Minus, NoSign, Plus};
9 use crate::{BigInt, UsizePromotion};
10
11 use core::cmp::Ordering;
12 use core::iter::Product;
13 use core::ops::{Mul, MulAssign};
14 use num_traits::{CheckedMul, FromPrimitive, One, Zero};
15
16 #[inline]
mac_with_carry( a: BigDigit, b: BigDigit, c: BigDigit, acc: &mut DoubleBigDigit, ) -> BigDigit17 pub(super) fn mac_with_carry(
18 a: BigDigit,
19 b: BigDigit,
20 c: BigDigit,
21 acc: &mut DoubleBigDigit,
22 ) -> BigDigit {
23 *acc += DoubleBigDigit::from(a);
24 *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
25 let lo = *acc as BigDigit;
26 *acc >>= big_digit::BITS;
27 lo
28 }
29
30 #[inline]
mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit31 fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
32 *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
33 let lo = *acc as BigDigit;
34 *acc >>= big_digit::BITS;
35 lo
36 }
37
38 /// Three argument multiply accumulate:
39 /// acc += b * c
mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit)40 fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
41 if c == 0 {
42 return;
43 }
44
45 let mut carry = 0;
46 let (a_lo, a_hi) = acc.split_at_mut(b.len());
47
48 for (a, &b) in a_lo.iter_mut().zip(b) {
49 *a = mac_with_carry(*a, b, c, &mut carry);
50 }
51
52 let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
53
54 let final_carry = if carry_hi == 0 {
55 __add2(a_hi, &[carry_lo])
56 } else {
57 __add2(a_hi, &[carry_hi, carry_lo])
58 };
59 assert_eq!(final_carry, 0, "carry overflow during multiplication!");
60 }
61
bigint_from_slice(slice: &[BigDigit]) -> BigInt62 fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
63 BigInt::from(biguint_from_vec(slice.to_vec()))
64 }
65
/// Three argument multiply accumulate:
/// acc += b * c
///
/// All three slices are little-endian digit sequences. The product `b * c` is
/// added into `acc` in place, so `acc` must be long enough to hold the result
/// (`mul3` allocates `x.len() + y.len() + 1` digits for it). The algorithm is
/// picked from the length of the shorter operand: long multiplication up to 32
/// digits, Karatsuba up to 256 digits, and Toom-3 above that.
#[allow(clippy::many_single_char_names)]
fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
    // Least-significant zeros have no effect on the output.
    // Each skipped low digit of an operand shifts the product up by one digit,
    // so `acc` is re-sliced by the same amount. An all-zero operand means the
    // product is zero and there is nothing to accumulate.
    if let Some(&0) = b.first() {
        if let Some(nz) = b.iter().position(|&d| d != 0) {
            b = &b[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }
    if let Some(&0) = c.first() {
        if let Some(nz) = c.iter().position(|&d| d != 0) {
            c = &c[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }

    // Rebind without `mut` now that trimming is done.
    let acc = acc;
    // x is the shorter operand (c on a tie), y the longer one.
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use three algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication:
        // one row per digit of x, each accumulated at that digit's offset.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() <= 256 {
        // Karatsuba multiplication:
        //
        // The idea is that we break x and y up into two smaller numbers that each have about half
        // as many digits, like so (note that multiplying by b is just a shift):
        //
        // x = x0 + x1 * b
        // y = y0 + y1 * b
        //
        // With some algebra, we can compute x * y with three smaller products, where the inputs to
        // each of the smaller products have only about half as many digits as x and y:
        //
        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
        //
        // x * y = x0 * y0
        //       + x0 * y1 * b
        //       + x1 * y0 * b
        //       + x1 * y1 * b^2
        //
        // Let p0 = x0 * y0 and p2 = x1 * y1:
        //
        // x * y = p0
        //       + (x0 * y1 + x1 * y0) * b
        //       + p2 * b^2
        //
        // The real trick is that middle term:
        //
        //         x0 * y1 + x1 * y0
        //
        //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
        //
        //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
        //
        // Now we complete the square:
        //
        //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
        //
        //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
        //
        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
        //
        // x * y = p0
        //       + (p0 + p2 - p1) * b
        //       + p2 * b^2
        //
        // Where the three intermediate products are:
        //
        // p0 = x0 * y0
        // p1 = (x1 - x0) * (y1 - y0)
        // p2 = x1 * y1
        //
        // In doing the computation, we take great care to avoid unnecessary temporary variables
        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
        // bit so we can use the same temporary variable for all the intermediate products:
        //
        // x * y = p2 * b^2 + p2 * b
        //       + p0 * b + p0
        //       - p1 * b
        //
        // The other trick we use is instead of doing explicit shifts, we slice acc at the
        // appropriate offset when doing the add.

        // When x is smaller than y, it's significantly faster to pick b such that x is split in
        // half, not y:
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        // We reuse the same BigUint for all the intermediate multiplies and have to size p
        // appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };

        // p2 = x1 * y1
        mac3(&mut p.data, x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data);
        add2(&mut acc[b * 2..], &p.data);

        // Zero out p before the next multiply:
        p.data.truncate(0);
        p.data.resize(len, 0);

        // p0 = x0 * y0
        mac3(&mut p.data, x0, y0);
        p.normalize();

        // acc += p0 + p0 * b
        add2(acc, &p.data);
        add2(&mut acc[b..], &p.data);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        // acc -= p1 * b, where the sign of p1 is the product of the two difference signs.
        match j0_sign * j1_sign {
            // p1 > 0: compute it in the scratch buffer, then subtract it.
            Plus => {
                p.data.truncate(0);
                p.data.resize(len, 0);

                mac3(&mut p.data, &j0.data, &j1.data);
                p.normalize();

                sub2(&mut acc[b..], &p.data);
            }
            // p1 < 0: subtracting p1 means adding |j0| * |j1|, which mac3 can
            // accumulate directly into acc without a temporary.
            Minus => {
                mac3(&mut acc[b..], &j0.data, &j1.data);
            }
            // One of the differences is zero, so p1 == 0 and nothing changes.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        //
        // `i` is the part size in digits, derived from the longer operand y.
        // Because of the `min`s below, the upper parts of x may be short or
        // empty when x is much shorter than y.
        let i = y.len() / 3 + 1;

        let x0_len = Ord::min(x.len(), i);
        let x1_len = Ord::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = Ord::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing an order two polynomial.
        // t is the digit base raised to the i-th power (i whole digits), so
        // multiplying by t is a digit shift rather than a real multiplication.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following order-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // The choice of evaluation points is arbitrary; we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2)
        // (p2 + x2) * 2 - x0 = 2*(x0 - x1 + 2*x2) - x0 = 4*x2 - 2*x1 + x0, likewise for y.
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear equations
        // (columns are the coefficients w4 w3 w2 w1 w0).
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solved equation (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations. Every division in it is exact, so the
        // arithmetic shifts by one are exact halvings.
        // comp3 = (w(-2) - w(1)) / 3
        let mut comp3: BigInt = (r3 - &r1) / 3u32;
        // comp1 = (w(1) - w(-1)) / 2 = w1 + w3
        let mut comp1: BigInt = (r1 - &r2) >> 1;
        // comp2 = w(-1) - w(0)
        let mut comp2: BigInt = r2 - &r0;
        // comp3 = w3
        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
        // comp2 = w2
        comp2 += &comp1 - &r4;
        // comp1 = w1
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        //
        //     let bits = u64::from(big_digit::BITS) * i as u64;
        //     let result = r0
        //         + (comp1 << bits)
        //         + (comp2 << (2 * bits))
        //         + (comp3 << (3 * bits))
        //         + (r4 << (4 * bits));
        //     let result_pos = result.to_biguint().unwrap();
        //     add2(&mut acc[..], &result_pos.data);
        //
        // But with less intermediate copying:
        // coefficient j lands at digit offset i * j, and negative coefficients
        // are subtracted.
        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
            match result.sign() {
                Plus => add2(&mut acc[i * j..], result.digits()),
                Minus => sub2(&mut acc[i * j..], result.digits()),
                NoSign => {}
            }
        }
    }
}
351
mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint352 fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
353 let len = x.len() + y.len() + 1;
354 let mut prod = BigUint { data: vec![0; len] };
355
356 mac3(&mut prod.data, x, y);
357 prod.normalized()
358 }
359
scalar_mul(a: &mut BigUint, b: BigDigit)360 fn scalar_mul(a: &mut BigUint, b: BigDigit) {
361 match b {
362 0 => a.set_zero(),
363 1 => {}
364 _ => {
365 if b.is_power_of_two() {
366 *a <<= b.trailing_zeros();
367 } else {
368 let mut carry = 0;
369 for a in a.data.iter_mut() {
370 *a = mul_with_carry(*a, b, &mut carry);
371 }
372 if carry != 0 {
373 a.data.push(carry as BigDigit);
374 }
375 }
376 }
377 }
378 }
379
sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint)380 fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
381 // Normalize:
382 if let Some(&0) = a.last() {
383 a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
384 }
385 if let Some(&0) = b.last() {
386 b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
387 }
388
389 match cmp_slice(a, b) {
390 Ordering::Greater => {
391 let mut a = a.to_vec();
392 sub2(&mut a, b);
393 (Plus, biguint_from_vec(a))
394 }
395 Ordering::Less => {
396 let mut b = b.to_vec();
397 sub2(&mut b, a);
398 (Minus, biguint_from_vec(b))
399 }
400 Ordering::Equal => (NoSign, Zero::zero()),
401 }
402 }
403
404 macro_rules! impl_mul {
405 ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
406 impl Mul<$Other> for $Self {
407 type Output = BigUint;
408
409 #[inline]
410 fn mul(self, other: $Other) -> BigUint {
411 match (&*self.data, &*other.data) {
412 // multiply by zero
413 (&[], _) | (_, &[]) => BigUint::zero(),
414 // multiply by a scalar
415 (_, &[digit]) => self * digit,
416 (&[digit], _) => other * digit,
417 // full multiplication
418 (x, y) => mul3(x, y),
419 }
420 }
421 }
422 )*}
423 }
424 impl_mul! {
425 impl Mul<BigUint> for BigUint;
426 impl Mul<BigUint> for &BigUint;
427 impl Mul<&BigUint> for BigUint;
428 impl Mul<&BigUint> for &BigUint;
429 }
430
431 macro_rules! impl_mul_assign {
432 ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
433 impl MulAssign<$Other> for BigUint {
434 #[inline]
435 fn mul_assign(&mut self, other: $Other) {
436 match (&*self.data, &*other.data) {
437 // multiply by zero
438 (&[], _) => {},
439 (_, &[]) => self.set_zero(),
440 // multiply by a scalar
441 (_, &[digit]) => *self *= digit,
442 (&[digit], _) => *self = other * digit,
443 // full multiplication
444 (x, y) => *self = mul3(x, y),
445 }
446 }
447 }
448 )*}
449 }
450 impl_mul_assign! {
451 impl MulAssign<BigUint> for BigUint;
452 impl MulAssign<&BigUint> for BigUint;
453 }
454
// Scalar multiplication plumbing. These macros are defined elsewhere in the
// crate. NOTE(review): judging by their names, the `promote_*` pair presumably
// routes the remaining unsigned primitive types to the u32/u64/u128 impls
// below. The `forward_*_commutative` lines presumably derive the by-reference
// and `scalar * BigUint` variants from the by-value impls. Confirm against the
// macro definitions.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
460
461 impl Mul<u32> for BigUint {
462 type Output = BigUint;
463
464 #[inline]
mul(mut self, other: u32) -> BigUint465 fn mul(mut self, other: u32) -> BigUint {
466 self *= other;
467 self
468 }
469 }
470 impl MulAssign<u32> for BigUint {
471 #[inline]
mul_assign(&mut self, other: u32)472 fn mul_assign(&mut self, other: u32) {
473 scalar_mul(self, other as BigDigit);
474 }
475 }
476
477 impl Mul<u64> for BigUint {
478 type Output = BigUint;
479
480 #[inline]
mul(mut self, other: u64) -> BigUint481 fn mul(mut self, other: u64) -> BigUint {
482 self *= other;
483 self
484 }
485 }
486 impl MulAssign<u64> for BigUint {
487 #[cfg(not(u64_digit))]
488 #[inline]
mul_assign(&mut self, other: u64)489 fn mul_assign(&mut self, other: u64) {
490 if let Some(other) = BigDigit::from_u64(other) {
491 scalar_mul(self, other);
492 } else {
493 let (hi, lo) = big_digit::from_doublebigdigit(other);
494 *self = mul3(&self.data, &[lo, hi]);
495 }
496 }
497
498 #[cfg(u64_digit)]
499 #[inline]
mul_assign(&mut self, other: u64)500 fn mul_assign(&mut self, other: u64) {
501 scalar_mul(self, other);
502 }
503 }
504
505 impl Mul<u128> for BigUint {
506 type Output = BigUint;
507
508 #[inline]
mul(mut self, other: u128) -> BigUint509 fn mul(mut self, other: u128) -> BigUint {
510 self *= other;
511 self
512 }
513 }
514
515 impl MulAssign<u128> for BigUint {
516 #[cfg(not(u64_digit))]
517 #[inline]
mul_assign(&mut self, other: u128)518 fn mul_assign(&mut self, other: u128) {
519 if let Some(other) = BigDigit::from_u128(other) {
520 scalar_mul(self, other);
521 } else {
522 *self = match u32_from_u128(other) {
523 (0, 0, c, d) => mul3(&self.data, &[d, c]),
524 (0, b, c, d) => mul3(&self.data, &[d, c, b]),
525 (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
526 };
527 }
528 }
529
530 #[cfg(u64_digit)]
531 #[inline]
mul_assign(&mut self, other: u128)532 fn mul_assign(&mut self, other: u128) {
533 if let Some(other) = BigDigit::from_u128(other) {
534 scalar_mul(self, other);
535 } else {
536 let (hi, lo) = big_digit::from_doublebigdigit(other);
537 *self = mul3(&self.data, &[lo, hi]);
538 }
539 }
540 }
541
542 impl CheckedMul for BigUint {
543 #[inline]
checked_mul(&self, v: &BigUint) -> Option<BigUint>544 fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
545 Some(self.mul(v))
546 }
547 }
548
// NOTE(review): `impl_product_iter_type!` is defined elsewhere in the crate.
// Presumably it implements `core::iter::Product` for `BigUint`, since this file
// imports `Product`. Confirm against the macro definition.
impl_product_iter_type!(BigUint);
550
/// Checks `sub_sign` against `BigInt` subtraction, including the equal-operands
/// (`NoSign`) branch and inputs carrying high-order zero digits, which
/// `sub_sign` must ignore.
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, val) = sub_sign(a, b);
        BigInt::from_biguint(sign, val)
    }

    let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let a_i = BigInt::from(a.clone());
    let b_i = BigInt::from(b.clone());

    assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);

    // Equal operands take the `NoSign` branch and yield zero.
    assert_eq!(sub_sign_i(&a.data, &a.data), BigInt::zero());

    // High-order zero digits must not affect the comparison or the result.
    let mut a_padded = a.data.clone();
    a_padded.extend_from_slice(&[0, 0]);
    assert_eq!(sub_sign_i(&a_padded, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a_padded), &b_i - &a_i);
    assert_eq!(sub_sign_i(&a_padded, &a.data), BigInt::zero());

    // All-zero and empty inputs both represent zero.
    assert_eq!(sub_sign_i(&[0, 0], &[]), BigInt::zero());
    assert_eq!(sub_sign_i(&[], &b.data), -&b_i);
}
569