1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements unsigned multi-precision integers (natural
6// numbers). They are the building blocks for the implementation
7// of signed integers, rationals, and floating-point numbers.
8//
9// Caution: This implementation relies on the function "alias"
10//          which assumes that (nat) slice capacities are never
11//          changed (no 3-operand slice expressions). If that
12//          changes, alias needs to be updated for correctness.
13
14package big
15
16import (
17	"internal/byteorder"
18	"math/bits"
19	"math/rand"
20	"sync"
21)
22
// An unsigned integer x of the form
//
//	x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements (x[0] is the least
// significant digit).
//
// A number is normalized if the slice contains no leading 0 digits,
// i.e. if its last element is non-zero.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word
35
// Small nat values used as operands elsewhere in the package.
// NOTE(review): presumably these are treated as read-only; confirm no caller
// uses them as a result receiver.
var (
	natOne  = nat{1}
	natTwo  = nat{2}
	natFive = nat{5}
	natTen  = nat{10}
)
42
43func (z nat) String() string {
44	return "0x" + string(z.itoa(false, 16))
45}
46
47func (z nat) norm() nat {
48	i := len(z)
49	for i > 0 && z[i-1] == 0 {
50		i--
51	}
52	return z[0:i]
53}
54
55func (z nat) make(n int) nat {
56	if n <= cap(z) {
57		return z[:n] // reuse z
58	}
59	if n == 1 {
60		// Most nats start small and stay that way; don't over-allocate.
61		return make(nat, 1)
62	}
63	// Choosing a good value for e has significant performance impact
64	// because it increases the chance that a value can be reused.
65	const e = 4 // extra capacity
66	return make(nat, n, n+e)
67}
68
69func (z nat) setWord(x Word) nat {
70	if x == 0 {
71		return z[:0]
72	}
73	z = z.make(1)
74	z[0] = x
75	return z
76}
77
78func (z nat) setUint64(x uint64) nat {
79	// single-word value
80	if w := Word(x); uint64(w) == x {
81		return z.setWord(w)
82	}
83	// 2-word value
84	z = z.make(2)
85	z[1] = Word(x >> 32)
86	z[0] = Word(x)
87	return z
88}
89
90func (z nat) set(x nat) nat {
91	z = z.make(len(x))
92	copy(z, x)
93	return z
94}
95
96func (z nat) add(x, y nat) nat {
97	m := len(x)
98	n := len(y)
99
100	switch {
101	case m < n:
102		return z.add(y, x)
103	case m == 0:
104		// n == 0 because m >= n; result is 0
105		return z[:0]
106	case n == 0:
107		// result is x
108		return z.set(x)
109	}
110	// m > 0
111
112	z = z.make(m + 1)
113	c := addVV(z[0:n], x, y)
114	if m > n {
115		c = addVW(z[n:m], x[n:], c)
116	}
117	z[m] = c
118
119	return z.norm()
120}
121
122func (z nat) sub(x, y nat) nat {
123	m := len(x)
124	n := len(y)
125
126	switch {
127	case m < n:
128		panic("underflow")
129	case m == 0:
130		// n == 0 because m >= n; result is 0
131		return z[:0]
132	case n == 0:
133		// result is x
134		return z.set(x)
135	}
136	// m > 0
137
138	z = z.make(m)
139	c := subVV(z[0:n], x, y)
140	if m > n {
141		c = subVW(z[n:], x[n:], c)
142	}
143	if c != 0 {
144		panic("underflow")
145	}
146
147	return z.norm()
148}
149
150func (x nat) cmp(y nat) (r int) {
151	m := len(x)
152	n := len(y)
153	if m != n || m == 0 {
154		switch {
155		case m < n:
156			r = -1
157		case m > n:
158			r = 1
159		}
160		return
161	}
162
163	i := m - 1
164	for i > 0 && x[i] == y[i] {
165		i--
166	}
167
168	switch {
169	case x[i] < y[i]:
170		r = -1
171	case x[i] > y[i]:
172		r = 1
173	}
174	return
175}
176
177func (z nat) mulAddWW(x nat, y, r Word) nat {
178	m := len(x)
179	if m == 0 || y == 0 {
180		return z.setWord(r) // result is r
181	}
182	// m > 0
183
184	z = z.make(m + 1)
185	z[m] = mulAddVWW(z[0:m], x, y, r)
186
187	return z.norm()
188}
189
190// basicMul multiplies x and y and leaves the result in z.
191// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
192func basicMul(z, x, y nat) {
193	clear(z[0 : len(x)+len(y)]) // initialize z
194	for i, d := range y {
195		if d != 0 {
196			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
197		}
198	}
199}
200
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
//
// It panics if x, y and m do not all have length n.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	// The working area is 2*n words; the result ends up in the low n words.
	z = z.make(n * 2)
	clear(z)
	var c Word // carry out of the previous iteration (0 or 1)
	for i := 0; i < n; i++ {
		d := y[i]
		// z[i:n+i] += x*d
		c2 := addMulVVW(z[i:n+i], x, d)
		// t is chosen so that adding m*t makes z[i] zero (k = -1/m mod 2**_W).
		t := z[i] * k
		c3 := addMulVVW(z[i:n+i], m, t)
		// Sum the three carries into z[n+i], tracking overflow of the sum in c.
		cx := c + c2
		cy := cx + c3
		z[n+i] = cy
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// The result sits in z[n:]; move it into z[:n], subtracting m once
	// if the final carry is set.
	if c != 0 {
		subVV(z[:n], z[n:], m)
	} else {
		copy(z[:n], z[n:])
	}
	return z[:n]
}
242
243// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
244// Factored out for readability - do not use outside karatsuba.
245func karatsubaAdd(z, x nat, n int) {
246	if c := addVV(z[0:n], z, x); c != 0 {
247		addVW(z[n:n+n>>1], z[n:], c)
248	}
249}
250
251// Like karatsubaAdd, but does subtract.
252func karatsubaSub(z, x nat, n int) {
253	if c := subVV(z[0:n], z, x); c != 0 {
254		subVW(z[n:n+n>>1], z[n:], c)
255	}
256}
257
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used. The length is measured in words.
var karatsubaThreshold = 40 // computed by calibrate_test.go
262
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
// The remainder of z is used as scratch space and is overwritten.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recurring)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	// p holds |xd*yd|; s decides whether it is added or subtracted.
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
361
362// alias reports whether x and y share the same base array.
363//
364// Note: alias assumes that the capacity of underlying arrays
365// is never changed for nat values; i.e. that there are
366// no 3-operand slice expressions in this code (or worse,
367// reflect-based operations to the same effect).
368func alias(x, y nat) bool {
369	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
370}
371
372// addAt implements z += x<<(_W*i); z must be long enough.
373// (we don't use nat.add because we need z to stay the same
374// slice, and we don't need to normalize z after each addition)
375func addAt(z, x nat, i int) {
376	if n := len(x); n > 0 {
377		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
378			j := i + n
379			if j < len(z) {
380				addVW(z[j:], z[j:], c)
381			}
382		}
383	}
384}
385
// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of threshold.
func karatsubaLen(n, threshold int) int {
	var shift uint
	// Halve n until it no longer exceeds threshold, counting the halvings.
	for ; n > threshold; shift++ {
		n >>= 1
	}
	return n << shift
}
398
// mul sets z = x*y and returns the normalized result.
// z's storage is reused unless z aliases x or y, in which case a fresh
// slice is allocated. Single-word y is handled by mulAddWW, operands
// shorter than karatsubaThreshold by basicMul, and everything else by
// karatsuba on the low k words plus schoolbook-style accumulation of the
// remaining partial products.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		// make x the longer operand
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n, karatsubaThreshold)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n] // z has final length but may be incomplete
	clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is scratch space for the partial products, taken from the pool.
		tp := getNat(3 * k)
		t := *tp

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of at most k words of x
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}

		putNat(tp)
	}

	return z.norm()
}
485
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
// A temporary of 2*len(x) words is taken from the nat pool and returned
// to it before basicSqr returns.
func basicSqr(z, x nat) {
	n := len(x)
	tp := getNat(2 * n)
	t := *tp // temporary variable to hold the products
	clear(t) // getNat does not guarantee zeroed contents
	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
	for i := 1; i < n; i++ {
		d := x[i]
		// z collects the squares x[i] * x[i]
		z[2*i+1], z[2*i] = mulWW(d, d)
		// t collects the products x[i] * x[j] where j < i
		t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
	}
	t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
	addVV(z, z, t)                              // combine the result
	putNat(tp)
}
507
// karatsubaSqr squares x and leaves the result in z.
// len(x) must be a power of 2 and len(z) >= 6*len(x).
// The (non-normalized) result is placed in z[0 : 2*len(x)].
//
// The algorithm and the layout of z are the same as for karatsuba.
func karatsubaSqr(z, x nat) {
	n := len(x)

	// Fall back to basicSqr for odd or small lengths.
	if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
		basicSqr(z[:2*n], x)
		return
	}

	// split x into two halves: x = x1*b + x0
	n2 := n >> 1
	x1, x0 := x[n2:], x[0:n2]

	karatsubaSqr(z, x0)     // z0 = x0*x0, in z[0:n]
	karatsubaSqr(z[n:], x1) // z2 = x1*x1, in z[n:2*n]

	// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
	// xd = |x1 - x0|
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 {
		subVV(xd, x0, x1)
	}

	// p = xd*xd
	p := z[n*3:]
	karatsubaSqr(p, xd)

	// save original z2:z0 before the partial products are summed in place
	r := z[n*4:]
	copy(r, z[:n*2])

	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
}
543
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication (basicMul); operands of at least
// basicSqrThreshold but fewer than karatsubaSqrThreshold words are squared
// with basicSqr; for operands of karatsubaSqrThreshold words or more
// we use the Karatsuba algorithm optimized for x == y.
var basicSqrThreshold = 20      // computed by calibrate_test.go
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
549
// sqr sets z = x*x and returns the normalized result.
// z's storage is reused unless z aliases x. The algorithm is chosen by
// len(x): basicMul below basicSqrThreshold, basicSqr below
// karatsubaSqrThreshold, and karatsubaSqr otherwise.
func (z nat) sqr(x nat) nat {
	n := len(x)
	switch {
	case n == 0:
		return z[:0]
	case n == 1:
		// single word: one double-word multiplication suffices
		d := x[0]
		z = z.make(2)
		z[1], z[0] = mulWW(d, d)
		return z.norm()
	}

	if alias(z, x) {
		z = nil // z is an alias for x - cannot reuse
	}

	if n < basicSqrThreshold {
		z = z.make(2 * n)
		basicMul(z, x, x)
		return z.norm()
	}
	if n < karatsubaSqrThreshold {
		z = z.make(2 * n)
		basicSqr(z, x)
		return z.norm()
	}

	// Use Karatsuba multiplication optimized for x == y.
	// The algorithm and layout of z are the same as for mul.

	// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2

	k := karatsubaLen(n, karatsubaSqrThreshold)

	x0 := x[0:k]
	z = z.make(max(6*k, 2*n)) // room for karatsubaSqr scratch and the full result
	karatsubaSqr(z, x0) // z = x0^2
	z = z[0 : 2*n]
	clear(z[2*k:]) // everything above x0^2 is scratch garbage

	if k < n {
		// t is scratch space for the remaining terms, taken from the pool.
		tp := getNat(2 * k)
		t := *tp
		x0 := x0.norm()
		x1 := x[k:]
		t = t.mul(x0, x1)
		addAt(z, t, k)
		addAt(z, t, k) // z = 2*x1*x0*b + x0^2
		t = t.sqr(x1)
		addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
		putNat(tp)
	}

	return z.norm()
}
606
607// mulRange computes the product of all the unsigned integers in the
608// range [a, b] inclusively. If a > b (empty range), the result is 1.
609func (z nat) mulRange(a, b uint64) nat {
610	switch {
611	case a == 0:
612		// cut long ranges short (optimization)
613		return z.setUint64(0)
614	case a > b:
615		return z.setUint64(1)
616	case a == b:
617		return z.setUint64(a)
618	case a+1 == b:
619		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
620	}
621	m := a + (b-a)/2 // avoid overflow
622	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
623}
624
625// getNat returns a *nat of len n. The contents may not be zero.
626// The pool holds *nat to avoid allocation when converting to interface{}.
627func getNat(n int) *nat {
628	var z *nat
629	if v := natPool.Get(); v != nil {
630		z = v.(*nat)
631	}
632	if z == nil {
633		z = new(nat)
634	}
635	*z = z.make(n)
636	if n > 0 {
637		(*z)[0] = 0xfedcb // break code expecting zero
638	}
639	return z
640}
641
// putNat returns x to natPool. The caller must not use x afterwards.
func putNat(x *nat) {
	natPool.Put(x)
}

// natPool holds the *nat temporaries handed out by getNat.
var natPool sync.Pool
647
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
// For the empty nat the result is 0.
func (x nat) bitLen() int {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	if i := len(x) - 1; i >= 0 {
		// bits.Len uses a lookup table for the low-order bits on some
		// architectures. Neutralize any input-dependent behavior by setting all
		// bits after the first one bit.
		top := uint(x[i])
		top |= top >> 1
		top |= top >> 2
		top |= top >> 4
		top |= top >> 8
		top |= top >> 16
		top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures
		// i full words below the top word, plus the bits used in the top word.
		return i*_W + bits.Len(top)
	}
	return 0
}
669
670// trailingZeroBits returns the number of consecutive least significant zero
671// bits of x.
672func (x nat) trailingZeroBits() uint {
673	if len(x) == 0 {
674		return 0
675	}
676	var i uint
677	for x[i] == 0 {
678		i++
679	}
680	// x[i] != 0
681	return i*_W + uint(bits.TrailingZeros(uint(x[i])))
682}
683
684// isPow2 returns i, true when x == 2**i and 0, false otherwise.
685func (x nat) isPow2() (uint, bool) {
686	var i uint
687	for x[i] == 0 {
688		i++
689	}
690	if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
691		return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
692	}
693	return 0, false
694}
695
696func same(x, y nat) bool {
697	return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
698}
699
700// z = x << s
701func (z nat) shl(x nat, s uint) nat {
702	if s == 0 {
703		if same(z, x) {
704			return z
705		}
706		if !alias(z, x) {
707			return z.set(x)
708		}
709	}
710
711	m := len(x)
712	if m == 0 {
713		return z[:0]
714	}
715	// m > 0
716
717	n := m + int(s/_W)
718	z = z.make(n + 1)
719	z[n] = shlVU(z[n-m:n], x, s%_W)
720	clear(z[0 : n-m])
721
722	return z.norm()
723}
724
725// z = x >> s
726func (z nat) shr(x nat, s uint) nat {
727	if s == 0 {
728		if same(z, x) {
729			return z
730		}
731		if !alias(z, x) {
732			return z.set(x)
733		}
734	}
735
736	m := len(x)
737	n := m - int(s/_W)
738	if n <= 0 {
739		return z[:0]
740	}
741	// n > 0
742
743	z = z.make(n)
744	shrVU(z, x[m-n:], s%_W)
745
746	return z.norm()
747}
748
749func (z nat) setBit(x nat, i uint, b uint) nat {
750	j := int(i / _W)
751	m := Word(1) << (i % _W)
752	n := len(x)
753	switch b {
754	case 0:
755		z = z.make(n)
756		copy(z, x)
757		if j >= n {
758			// no need to grow
759			return z
760		}
761		z[j] &^= m
762		return z.norm()
763	case 1:
764		if j >= n {
765			z = z.make(j + 1)
766			clear(z[n:])
767		} else {
768			z = z.make(n)
769		}
770		copy(z, x)
771		z[j] |= m
772		// no need to normalize
773		return z
774	}
775	panic("set bit is not 0 or 1")
776}
777
778// bit returns the value of the i'th bit, with lsb == bit 0.
779func (x nat) bit(i uint) uint {
780	j := i / _W
781	if j >= uint(len(x)) {
782		return 0
783	}
784	// 0 <= j < len(x)
785	return uint(x[j] >> (i % _W) & 1)
786}
787
788// sticky returns 1 if there's a 1 bit within the
789// i least significant bits, otherwise it returns 0.
790func (x nat) sticky(i uint) uint {
791	j := i / _W
792	if j >= uint(len(x)) {
793		if len(x) == 0 {
794			return 0
795		}
796		return 1
797	}
798	// 0 <= j < len(x)
799	for _, x := range x[:j] {
800		if x != 0 {
801			return 1
802		}
803	}
804	if x[j]<<(_W-i%_W) != 0 {
805		return 1
806	}
807	return 0
808}
809
810func (z nat) and(x, y nat) nat {
811	m := len(x)
812	n := len(y)
813	if m > n {
814		m = n
815	}
816	// m <= n
817
818	z = z.make(m)
819	for i := 0; i < m; i++ {
820		z[i] = x[i] & y[i]
821	}
822
823	return z.norm()
824}
825
826// trunc returns z = x mod 2ⁿ.
827func (z nat) trunc(x nat, n uint) nat {
828	w := (n + _W - 1) / _W
829	if uint(len(x)) < w {
830		return z.set(x)
831	}
832	z = z.make(int(w))
833	copy(z, x)
834	if n%_W != 0 {
835		z[len(z)-1] &= 1<<(n%_W) - 1
836	}
837	return z.norm()
838}
839
840func (z nat) andNot(x, y nat) nat {
841	m := len(x)
842	n := len(y)
843	if n > m {
844		n = m
845	}
846	// m >= n
847
848	z = z.make(m)
849	for i := 0; i < n; i++ {
850		z[i] = x[i] &^ y[i]
851	}
852	copy(z[n:m], x[n:m])
853
854	return z.norm()
855}
856
857func (z nat) or(x, y nat) nat {
858	m := len(x)
859	n := len(y)
860	s := x
861	if m < n {
862		n, m = m, n
863		s = y
864	}
865	// m >= n
866
867	z = z.make(m)
868	for i := 0; i < n; i++ {
869		z[i] = x[i] | y[i]
870	}
871	copy(z[n:m], s[n:m])
872
873	return z.norm()
874}
875
876func (z nat) xor(x, y nat) nat {
877	m := len(x)
878	n := len(y)
879	s := x
880	if m < n {
881		n, m = m, n
882		s = y
883	}
884	// m >= n
885
886	z = z.make(m)
887	for i := 0; i < n; i++ {
888		z[i] = x[i] ^ y[i]
889	}
890	copy(z[n:m], s[n:m])
891
892	return z.norm()
893}
894
895// random creates a random integer in [0..limit), using the space in z if
896// possible. n is the bit length of limit.
897func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
898	if alias(z, limit) {
899		z = nil // z is an alias for limit - cannot reuse
900	}
901	z = z.make(len(limit))
902
903	bitLengthOfMSW := uint(n % _W)
904	if bitLengthOfMSW == 0 {
905		bitLengthOfMSW = _W
906	}
907	mask := Word((1 << bitLengthOfMSW) - 1)
908
909	for {
910		switch _W {
911		case 32:
912			for i := range z {
913				z[i] = Word(rand.Uint32())
914			}
915		case 64:
916			for i := range z {
917				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
918			}
919		default:
920			panic("unknown word size")
921		}
922		z[len(limit)-1] &= mask
923		if z.cmp(limit) < 0 {
924			break
925		}
926	}
927
928	return z.norm()
929}
930
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
// When slow is set, the Montgomery and windowed fast paths are skipped
// and the bit-by-bit square-and-multiply loop is always used.
func (z nat) expNN(x, y, m nat, slow bool) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// 0**y = 0
	if len(x) == 0 {
		return z.setWord(0)
	}
	// x > 0

	// 1**y = 1
	if len(x) == 1 && x[0] == 1 {
		return z.setWord(1)
	}
	// x > 1

	// x**1 == x
	if len(y) == 1 && y[0] == 1 {
		if len(m) != 0 {
			return z.rem(x, m)
		}
		return z.set(x)
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))

		// If the exponent is large, we use the Montgomery method for odd values,
		// and a 4-bit, windowed exponentiation for powers of two,
		// and a CRT-decomposed Montgomery method for the remaining values
		// (even values times non-trivial odd values, which decompose into one
		// instance of each of the first two cases).
		if len(y) > 1 && !slow {
			if m[0]&1 == 1 {
				return z.expNNMontgomery(x, y, m)
			}
			if logM, ok := m.isPow2(); ok {
				return z.expNNWindowed(x, y, logM)
			}
			return z.expNNMontgomeryEven(x, y, m)
		}
	}

	z = z.set(x)
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	// Shift out the leading zeros and the top 1 bit of v; that top bit
	// is accounted for by the initial z = x.
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	w := _W - int(shift) // remaining bits in the top word of y
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.sqr(z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			// Reduce mod m, then rotate the buffers so that z holds the remainder.
			zz, r = zz.div(r, z, m)
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Process the remaining words of y, most significant first.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.sqr(z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}
1048
// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
// and then uses the Chinese Remainder Theorem to combine the results.
// The recursive call using m1 will use expNNWindowed,
// while the recursive call using m2 will use expNNMontgomery.
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
	// Split m = m₁ × m₂ where m₁ = 2ⁿ
	n := m.trailingZeroBits()
	m1 := nat(nil).shl(natOne, n)
	m2 := nat(nil).shr(m, n)

	// We want z = x**y mod m.
	// z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
	// z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
	// (We are using the math/big convention for names here,
	// where the computation is z = x**y mod m, so its parts are z1 and z2.
	// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
	z1 := nat(nil).expNN(x, y, m1, false)
	z2 := nat(nil).expNN(x, y, m2, false)

	// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
	// which uses only a single modInverse (and an easy one at that).
	//	p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
	//	z = z₂ + p × m₂
	// The final addition is in range because:
	//	z = z₂ + p × m₂
	//	  ≤ z₂ + (m₁-1) × m₂
	//	  < m₂ + (m₁-1) × m₂
	//	  = m₁ × m₂
	//	  = m.
	// Start with z = z₂; p × m₂ is added at the end.
	z = z.set(z2)

	// Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
	z1 = z1.subMod2N(z1, z2, n)

	// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
	m2inv := nat(nil).modInverse(m2, m1)
	z2 = z2.mul(z1, m2inv)
	z2 = z2.trunc(z2, n)

	// Reuse z1 for p * m2.
	z = z.add(z, z1.mul(z2, m2))

	return z
}
1097
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
// It panics if len(y) <= 1. Temporaries are taken from the nat pool and
// returned to it before the function returns on the main path.
func (z nat) expNNWindowed(x, y nat, logM uint) nat {
	if len(y) <= 1 {
		panic("big: misuse of expNNWindowed")
	}
	if x[0]&1 == 0 {
		// len(y) > 1, so y  > logM.
		// x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
		return z.setWord(0)
	}
	// x is odd, so x**y is odd and thus 1 mod 2.
	if logM == 1 {
		return z.setWord(1)
	}

	// zz is used to avoid allocating in mul as otherwise
	// the arguments would alias.
	w := int((logM + _W - 1) / _W) // number of words needed for logM bits
	zzp := getNat(w)
	zz := *zzp

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]*nat
	for i := range powers {
		powers[i] = getNat(w)
	}
	*powers[0] = powers[0].set(natOne)
	*powers[1] = powers[1].trunc(x, logM)
	// Fill in the table two entries at a time:
	// powers[i] = powers[i/2]**2 and powers[i+1] = powers[i]*x, both mod 2**logM.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := powers[i/2], powers[i], powers[i+1]
		*p = p.sqr(*p2)
		*p = p.trunc(*p, logM)
		*p1 = p1.mul(*p, x)
		*p1 = p1.trunc(*p1, logM)
	}

	// Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
	// so we can compute x**(y mod 2**(logM-1)) instead of x**y.
	// That is, we can throw away all but the bottom logM-1 bits of y.
	// Instead of allocating a new y, we start reading y at the right word
	// and truncate it appropriately at the start of the loop.
	i := len(y) - 1
	mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
	mmask := ^Word(0)
	if mbits := (logM - 1) & (_W - 1); mbits != 0 {
		mmask = (1 << mbits) - 1
	}
	if i > mtop {
		i = mtop
	}
	advance := false // set once the first window has been consumed
	z = z.setWord(1)
	for ; i >= 0; i-- {
		yi := y[i]
		if i == mtop {
			yi &= mmask
		}
		for j := 0; j < _W; j += n {
			if advance {
				// Account for use of 4 bits in previous iteration.
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)
			}

			// Multiply by the table entry selected by the top n bits of yi.
			zz = zz.mul(z, *powers[yi>>(_W-n)])
			zz, z = z, zz
			z = z.trunc(z, logM)

			yi <<= n
			advance = true
		}
	}

	// zz and z have been swapped repeatedly; store the current zz back
	// so that the pool receives the buffer that is not being returned.
	*zzp = zz
	putNat(zzp)
	for i := range powers {
		putNat(powers[i])
	}

	return z.norm()
}
1196
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// NOTE(review): the k0 computation presumably requires m to be odd; the
// visible caller (expNN) only takes this path when m[0]&1 == 1.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// zero-extend x to numWords words
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = nat(nil).div(RR, zz, m)
	if len(RR) < numWords {
		// zero-extend RR to numWords words, reusing zz's storage
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i
	// (in Montgomery form, since each entry is produced by montgomery with RR or powers[1])
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Square four times before every window except the very first.
			if i != len(y)-1 || j != 0 {
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// Multiply by the table entry selected by the top n bits of yi.
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
1289
// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
//
// Every word of z is written in full (including zero bytes) as long as
// there is room; bytes of buf in front of the written words are left
// untouched. Leading zero bytes among the written ones are counted as
// unused in the result.
func (z nat) bytes(buf []byte) (i int) {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	i = len(buf)
	// Walk the words from least to most significant, filling buf from
	// its end towards its beginning, one byte at a time.
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			if i >= 0 {
				buf[i] = byte(d)
			} else if byte(d) != 0 {
				// Out of room: only zero bytes may be dropped.
				panic("math/big: buffer too small to fit value")
			}
			d >>= 8
		}
	}

	// i is negative if z has more bytes than buf (all excess bytes were zero).
	if i < 0 {
		i = 0
	}
	// Skip leading zero bytes of the encoded value.
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}
1320
1321// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
1322func bigEndianWord(buf []byte) Word {
1323	if _W == 64 {
1324		return Word(byteorder.BeUint64(buf))
1325	}
1326	return Word(byteorder.BeUint32(buf))
1327}
1328
1329// setBytes interprets buf as the bytes of a big-endian unsigned
1330// integer, sets z to that value, and returns z.
1331func (z nat) setBytes(buf []byte) nat {
1332	z = z.make((len(buf) + _S - 1) / _S)
1333
1334	i := len(buf)
1335	for k := 0; i >= _S; k++ {
1336		z[k] = bigEndianWord(buf[i-_S : i])
1337		i -= _S
1338	}
1339	if i > 0 {
1340		var d Word
1341		for s := uint(0); i > 0; s += 8 {
1342			d |= Word(buf[i-1]) << s
1343			i--
1344		}
1345		z[len(z)-1] = d
1346	}
1347
1348	return z.norm()
1349}
1350
1351// sqrt sets z = ⌊√x⌋
1352func (z nat) sqrt(x nat) nat {
1353	if x.cmp(natOne) <= 0 {
1354		return z.set(x)
1355	}
1356	if alias(z, x) {
1357		z = nil
1358	}
1359
1360	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
1361	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
1362	// https://members.loria.fr/PZimmermann/mca/pub226.html
1363	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
1364	// otherwise it converges to the correct z and stays there.
1365	var z1, z2 nat
1366	z1 = z
1367	z1 = z1.setUint64(1)
1368	z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
1369	for n := 0; ; n++ {
1370		z2, _ = z2.div(nil, x, z1)
1371		z2 = z2.add(z2, z1)
1372		z2 = z2.shr(z2, 1)
1373		if z2.cmp(z1) >= 0 {
1374			// z1 is answer.
1375			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
1376			if n&1 == 0 {
1377				return z1
1378			}
1379			return z.set(z1)
1380		}
1381		z1, z2 = z2, z1
1382	}
1383}
1384
1385// subMod2N returns z = (x - y) mod 2ⁿ.
1386func (z nat) subMod2N(x, y nat, n uint) nat {
1387	if uint(x.bitLen()) > n {
1388		if alias(z, x) {
1389			// ok to overwrite x in place
1390			x = x.trunc(x, n)
1391		} else {
1392			x = nat(nil).trunc(x, n)
1393		}
1394	}
1395	if uint(y.bitLen()) > n {
1396		if alias(z, y) {
1397			// ok to overwrite y in place
1398			y = y.trunc(y, n)
1399		} else {
1400			y = nat(nil).trunc(y, n)
1401		}
1402	}
1403	if x.cmp(y) >= 0 {
1404		return z.sub(x, y)
1405	}
1406	// x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
1407	z = z.sub(y, x)
1408	for uint(len(z))*_W < n {
1409		z = append(z, 0)
1410	}
1411	for i := range z {
1412		z[i] = ^z[i]
1413	}
1414	z = z.trunc(z, n)
1415	return z.add(z, natOne)
1416}
1417