1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// This file implements unsigned multi-precision integers (natural 6// numbers). They are the building blocks for the implementation 7// of signed integers, rationals, and floating-point numbers. 8// 9// Caution: This implementation relies on the function "alias" 10// which assumes that (nat) slice capacities are never 11// changed (no 3-operand slice expressions). If that 12// changes, alias needs to be updated for correctness. 13 14package big 15 16import ( 17 "internal/byteorder" 18 "math/bits" 19 "math/rand" 20 "sync" 21) 22 23// An unsigned integer x of the form 24// 25// x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 26// 27// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 28// with the digits x[i] as the slice elements. 29// 30// A number is normalized if the slice contains no leading 0 digits. 31// During arithmetic operations, denormalized values may occur but are 32// always normalized before returning the final result. The normalized 33// representation of 0 is the empty or nil slice (length = 0). 34type nat []Word 35 36var ( 37 natOne = nat{1} 38 natTwo = nat{2} 39 natFive = nat{5} 40 natTen = nat{10} 41) 42 43func (z nat) String() string { 44 return "0x" + string(z.itoa(false, 16)) 45} 46 47func (z nat) norm() nat { 48 i := len(z) 49 for i > 0 && z[i-1] == 0 { 50 i-- 51 } 52 return z[0:i] 53} 54 55func (z nat) make(n int) nat { 56 if n <= cap(z) { 57 return z[:n] // reuse z 58 } 59 if n == 1 { 60 // Most nats start small and stay that way; don't over-allocate. 61 return make(nat, 1) 62 } 63 // Choosing a good value for e has significant performance impact 64 // because it increases the chance that a value can be reused. 
65 const e = 4 // extra capacity 66 return make(nat, n, n+e) 67} 68 69func (z nat) setWord(x Word) nat { 70 if x == 0 { 71 return z[:0] 72 } 73 z = z.make(1) 74 z[0] = x 75 return z 76} 77 78func (z nat) setUint64(x uint64) nat { 79 // single-word value 80 if w := Word(x); uint64(w) == x { 81 return z.setWord(w) 82 } 83 // 2-word value 84 z = z.make(2) 85 z[1] = Word(x >> 32) 86 z[0] = Word(x) 87 return z 88} 89 90func (z nat) set(x nat) nat { 91 z = z.make(len(x)) 92 copy(z, x) 93 return z 94} 95 96func (z nat) add(x, y nat) nat { 97 m := len(x) 98 n := len(y) 99 100 switch { 101 case m < n: 102 return z.add(y, x) 103 case m == 0: 104 // n == 0 because m >= n; result is 0 105 return z[:0] 106 case n == 0: 107 // result is x 108 return z.set(x) 109 } 110 // m > 0 111 112 z = z.make(m + 1) 113 c := addVV(z[0:n], x, y) 114 if m > n { 115 c = addVW(z[n:m], x[n:], c) 116 } 117 z[m] = c 118 119 return z.norm() 120} 121 122func (z nat) sub(x, y nat) nat { 123 m := len(x) 124 n := len(y) 125 126 switch { 127 case m < n: 128 panic("underflow") 129 case m == 0: 130 // n == 0 because m >= n; result is 0 131 return z[:0] 132 case n == 0: 133 // result is x 134 return z.set(x) 135 } 136 // m > 0 137 138 z = z.make(m) 139 c := subVV(z[0:n], x, y) 140 if m > n { 141 c = subVW(z[n:], x[n:], c) 142 } 143 if c != 0 { 144 panic("underflow") 145 } 146 147 return z.norm() 148} 149 150func (x nat) cmp(y nat) (r int) { 151 m := len(x) 152 n := len(y) 153 if m != n || m == 0 { 154 switch { 155 case m < n: 156 r = -1 157 case m > n: 158 r = 1 159 } 160 return 161 } 162 163 i := m - 1 164 for i > 0 && x[i] == y[i] { 165 i-- 166 } 167 168 switch { 169 case x[i] < y[i]: 170 r = -1 171 case x[i] > y[i]: 172 r = 1 173 } 174 return 175} 176 177func (z nat) mulAddWW(x nat, y, r Word) nat { 178 m := len(x) 179 if m == 0 || y == 0 { 180 return z.setWord(r) // result is r 181 } 182 // m > 0 183 184 z = z.make(m + 1) 185 z[m] = mulAddVWW(z[0:m], x, y, r) 186 187 return z.norm() 188} 189 190// 
basicMul multiplies x and y and leaves the result in z. 191// The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 192func basicMul(z, x, y nat) { 193 clear(z[0 : len(x)+len(y)]) // initialize z 194 for i, d := range y { 195 if d != 0 { 196 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 197 } 198 } 199} 200 201// montgomery computes z mod m = x*y*2**(-n*_W) mod m, 202// assuming k = -1/m mod 2**_W. 203// z is used for storing the result which is returned; 204// z must not alias x, y or m. 205// See Gueron, "Efficient Software Implementations of Modular Exponentiation". 206// https://eprint.iacr.org/2011/239.pdf 207// In the terminology of that paper, this is an "Almost Montgomery Multiplication": 208// x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result 209// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m. 210func (z nat) montgomery(x, y, m nat, k Word, n int) nat { 211 // This code assumes x, y, m are all the same length, n. 212 // (required by addMulVVW and the for loop). 213 // It also assumes that x, y are already reduced mod m, 214 // or else the result will not be properly reduced. 215 if len(x) != n || len(y) != n || len(m) != n { 216 panic("math/big: mismatched montgomery number lengths") 217 } 218 z = z.make(n * 2) 219 clear(z) 220 var c Word 221 for i := 0; i < n; i++ { 222 d := y[i] 223 c2 := addMulVVW(z[i:n+i], x, d) 224 t := z[i] * k 225 c3 := addMulVVW(z[i:n+i], m, t) 226 cx := c + c2 227 cy := cx + c3 228 z[n+i] = cy 229 if cx < c2 || cy < c3 { 230 c = 1 231 } else { 232 c = 0 233 } 234 } 235 if c != 0 { 236 subVV(z[:n], z[n:], m) 237 } else { 238 copy(z[:n], z[n:]) 239 } 240 return z[:n] 241} 242 243// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 244// Factored out for readability - do not use outside karatsuba. 
func karatsubaAdd(z, x nat, n int) {
	// Add x into the low n words of z; if that overflows, propagate the
	// carry through the next n>>1 words of z.
	if c := addVV(z[0:n], z, x); c != 0 {
		addVW(z[n:n+n>>1], z[n:], c)
	}
}

// karatsubaSub is like karatsubaAdd, but subtracts x from z
// (propagating a borrow instead of a carry).
func karatsubaSub(z, x nat, n int) {
	if c := subVV(z[0:n], z, x); c != 0 {
		subVW(z[n:n+n>>1], z[n:], c)
	}
}

// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
var karatsubaThreshold = 40 // computed by calibrate_test.go

// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recurring)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// The three middle terms are added at word offset n2 (i.e. scaled by b);
	// p is subtracted instead when the sign s of xd*yd is negative.
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
//
// Note: alias assumes that the capacity of underlying arrays
// is never changed for nat values; i.e. that there are
// no 3-operand slice expressions in this code (or worse,
// reflect-based operations to the same effect).
func alias(x, y nat) bool {
	// Compare the addresses of the last elements of the two slices when
	// each is extended to its full capacity.
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			// Propagate the carry into the words above the added range,
			// if there are any.
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of threshold.
func karatsubaLen(n, threshold int) int {
	// Halve n until it no longer exceeds threshold, counting the halvings,
	// then scale the remainder back up by the same power of two.
	i := uint(0)
	for n > threshold {
		n >>= 1
		i++
	}
	return n << i
}

// mul sets z = x*y and returns the normalized result.
// z's storage is not reused if it aliases x or y.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n, karatsubaThreshold)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is scratch space taken from natPool; it is returned below.
		tp := getNat(3 * k)
		t := *tp

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of at most k words of x, starting at word i.
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}

		putNat(tp)
	}

	return z.norm()
}

// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
func basicSqr(z, x nat) {
	n := len(x)
	tp := getNat(2 * n)
	t := *tp // temporary variable to hold the products
	clear(t) // getNat does not return zeroed memory
	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
	for i := 1; i < n; i++ {
		d := x[i]
		// z collects the squares x[i] * x[i]
		z[2*i+1], z[2*i] = mulWW(d, d)
		// t collects the products x[i] * x[j] where j < i
		t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
	}
	t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
	addVV(z, z, t)                              // combine the result
	putNat(tp)
}

// karatsubaSqr squares x and leaves the result in z.
// len(x) must be a power of 2 and len(z) >= 6*len(x).
510// The (non-normalized) result is placed in z[0 : 2*len(x)]. 511// 512// The algorithm and the layout of z are the same as for karatsuba. 513func karatsubaSqr(z, x nat) { 514 n := len(x) 515 516 if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { 517 basicSqr(z[:2*n], x) 518 return 519 } 520 521 n2 := n >> 1 522 x1, x0 := x[n2:], x[0:n2] 523 524 karatsubaSqr(z, x0) 525 karatsubaSqr(z[n:], x1) 526 527 // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 528 xd := z[2*n : 2*n+n2] 529 if subVV(xd, x1, x0) != 0 { 530 subVV(xd, x0, x1) 531 } 532 533 p := z[n*3:] 534 karatsubaSqr(p, xd) 535 536 r := z[n*4:] 537 copy(r, z[:n*2]) 538 539 karatsubaAdd(z[n2:], r, n) 540 karatsubaAdd(z[n2:], r[n:], n) 541 karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 542} 543 544// Operands that are shorter than basicSqrThreshold are squared using 545// "grade school" multiplication; for operands longer than karatsubaSqrThreshold 546// we use the Karatsuba algorithm optimized for x == y. 547var basicSqrThreshold = 20 // computed by calibrate_test.go 548var karatsubaSqrThreshold = 260 // computed by calibrate_test.go 549 550// z = x*x 551func (z nat) sqr(x nat) nat { 552 n := len(x) 553 switch { 554 case n == 0: 555 return z[:0] 556 case n == 1: 557 d := x[0] 558 z = z.make(2) 559 z[1], z[0] = mulWW(d, d) 560 return z.norm() 561 } 562 563 if alias(z, x) { 564 z = nil // z is an alias for x - cannot reuse 565 } 566 567 if n < basicSqrThreshold { 568 z = z.make(2 * n) 569 basicMul(z, x, x) 570 return z.norm() 571 } 572 if n < karatsubaSqrThreshold { 573 z = z.make(2 * n) 574 basicSqr(z, x) 575 return z.norm() 576 } 577 578 // Use Karatsuba multiplication optimized for x == y. 579 // The algorithm and layout of z are the same as for mul. 
580 581 // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 582 583 k := karatsubaLen(n, karatsubaSqrThreshold) 584 585 x0 := x[0:k] 586 z = z.make(max(6*k, 2*n)) 587 karatsubaSqr(z, x0) // z = x0^2 588 z = z[0 : 2*n] 589 clear(z[2*k:]) 590 591 if k < n { 592 tp := getNat(2 * k) 593 t := *tp 594 x0 := x0.norm() 595 x1 := x[k:] 596 t = t.mul(x0, x1) 597 addAt(z, t, k) 598 addAt(z, t, k) // z = 2*x1*x0*b + x0^2 599 t = t.sqr(x1) 600 addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 601 putNat(tp) 602 } 603 604 return z.norm() 605} 606 607// mulRange computes the product of all the unsigned integers in the 608// range [a, b] inclusively. If a > b (empty range), the result is 1. 609func (z nat) mulRange(a, b uint64) nat { 610 switch { 611 case a == 0: 612 // cut long ranges short (optimization) 613 return z.setUint64(0) 614 case a > b: 615 return z.setUint64(1) 616 case a == b: 617 return z.setUint64(a) 618 case a+1 == b: 619 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 620 } 621 m := a + (b-a)/2 // avoid overflow 622 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 623} 624 625// getNat returns a *nat of len n. The contents may not be zero. 626// The pool holds *nat to avoid allocation when converting to interface{}. 627func getNat(n int) *nat { 628 var z *nat 629 if v := natPool.Get(); v != nil { 630 z = v.(*nat) 631 } 632 if z == nil { 633 z = new(nat) 634 } 635 *z = z.make(n) 636 if n > 0 { 637 (*z)[0] = 0xfedcb // break code expecting zero 638 } 639 return z 640} 641 642func putNat(x *nat) { 643 natPool.Put(x) 644} 645 646var natPool sync.Pool 647 648// bitLen returns the length of x in bits. 649// Unlike most methods, it works even if x is not normalized. 650func (x nat) bitLen() int { 651 // This function is used in cryptographic operations. It must not leak 652 // anything but the Int's sign and bit size through side-channels. Any 653 // changes must be reviewed by a security expert. 
654 if i := len(x) - 1; i >= 0 { 655 // bits.Len uses a lookup table for the low-order bits on some 656 // architectures. Neutralize any input-dependent behavior by setting all 657 // bits after the first one bit. 658 top := uint(x[i]) 659 top |= top >> 1 660 top |= top >> 2 661 top |= top >> 4 662 top |= top >> 8 663 top |= top >> 16 664 top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures 665 return i*_W + bits.Len(top) 666 } 667 return 0 668} 669 670// trailingZeroBits returns the number of consecutive least significant zero 671// bits of x. 672func (x nat) trailingZeroBits() uint { 673 if len(x) == 0 { 674 return 0 675 } 676 var i uint 677 for x[i] == 0 { 678 i++ 679 } 680 // x[i] != 0 681 return i*_W + uint(bits.TrailingZeros(uint(x[i]))) 682} 683 684// isPow2 returns i, true when x == 2**i and 0, false otherwise. 685func (x nat) isPow2() (uint, bool) { 686 var i uint 687 for x[i] == 0 { 688 i++ 689 } 690 if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 { 691 return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true 692 } 693 return 0, false 694} 695 696func same(x, y nat) bool { 697 return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0] 698} 699 700// z = x << s 701func (z nat) shl(x nat, s uint) nat { 702 if s == 0 { 703 if same(z, x) { 704 return z 705 } 706 if !alias(z, x) { 707 return z.set(x) 708 } 709 } 710 711 m := len(x) 712 if m == 0 { 713 return z[:0] 714 } 715 // m > 0 716 717 n := m + int(s/_W) 718 z = z.make(n + 1) 719 z[n] = shlVU(z[n-m:n], x, s%_W) 720 clear(z[0 : n-m]) 721 722 return z.norm() 723} 724 725// z = x >> s 726func (z nat) shr(x nat, s uint) nat { 727 if s == 0 { 728 if same(z, x) { 729 return z 730 } 731 if !alias(z, x) { 732 return z.set(x) 733 } 734 } 735 736 m := len(x) 737 n := m - int(s/_W) 738 if n <= 0 { 739 return z[:0] 740 } 741 // n > 0 742 743 z = z.make(n) 744 shrVU(z, x[m-n:], s%_W) 745 746 return z.norm() 747} 748 749func (z nat) setBit(x nat, i uint, b uint) nat { 750 j := int(i / _W) 751 m := 
Word(1) << (i % _W) 752 n := len(x) 753 switch b { 754 case 0: 755 z = z.make(n) 756 copy(z, x) 757 if j >= n { 758 // no need to grow 759 return z 760 } 761 z[j] &^= m 762 return z.norm() 763 case 1: 764 if j >= n { 765 z = z.make(j + 1) 766 clear(z[n:]) 767 } else { 768 z = z.make(n) 769 } 770 copy(z, x) 771 z[j] |= m 772 // no need to normalize 773 return z 774 } 775 panic("set bit is not 0 or 1") 776} 777 778// bit returns the value of the i'th bit, with lsb == bit 0. 779func (x nat) bit(i uint) uint { 780 j := i / _W 781 if j >= uint(len(x)) { 782 return 0 783 } 784 // 0 <= j < len(x) 785 return uint(x[j] >> (i % _W) & 1) 786} 787 788// sticky returns 1 if there's a 1 bit within the 789// i least significant bits, otherwise it returns 0. 790func (x nat) sticky(i uint) uint { 791 j := i / _W 792 if j >= uint(len(x)) { 793 if len(x) == 0 { 794 return 0 795 } 796 return 1 797 } 798 // 0 <= j < len(x) 799 for _, x := range x[:j] { 800 if x != 0 { 801 return 1 802 } 803 } 804 if x[j]<<(_W-i%_W) != 0 { 805 return 1 806 } 807 return 0 808} 809 810func (z nat) and(x, y nat) nat { 811 m := len(x) 812 n := len(y) 813 if m > n { 814 m = n 815 } 816 // m <= n 817 818 z = z.make(m) 819 for i := 0; i < m; i++ { 820 z[i] = x[i] & y[i] 821 } 822 823 return z.norm() 824} 825 826// trunc returns z = x mod 2ⁿ. 
827func (z nat) trunc(x nat, n uint) nat { 828 w := (n + _W - 1) / _W 829 if uint(len(x)) < w { 830 return z.set(x) 831 } 832 z = z.make(int(w)) 833 copy(z, x) 834 if n%_W != 0 { 835 z[len(z)-1] &= 1<<(n%_W) - 1 836 } 837 return z.norm() 838} 839 840func (z nat) andNot(x, y nat) nat { 841 m := len(x) 842 n := len(y) 843 if n > m { 844 n = m 845 } 846 // m >= n 847 848 z = z.make(m) 849 for i := 0; i < n; i++ { 850 z[i] = x[i] &^ y[i] 851 } 852 copy(z[n:m], x[n:m]) 853 854 return z.norm() 855} 856 857func (z nat) or(x, y nat) nat { 858 m := len(x) 859 n := len(y) 860 s := x 861 if m < n { 862 n, m = m, n 863 s = y 864 } 865 // m >= n 866 867 z = z.make(m) 868 for i := 0; i < n; i++ { 869 z[i] = x[i] | y[i] 870 } 871 copy(z[n:m], s[n:m]) 872 873 return z.norm() 874} 875 876func (z nat) xor(x, y nat) nat { 877 m := len(x) 878 n := len(y) 879 s := x 880 if m < n { 881 n, m = m, n 882 s = y 883 } 884 // m >= n 885 886 z = z.make(m) 887 for i := 0; i < n; i++ { 888 z[i] = x[i] ^ y[i] 889 } 890 copy(z[n:m], s[n:m]) 891 892 return z.norm() 893} 894 895// random creates a random integer in [0..limit), using the space in z if 896// possible. n is the bit length of limit. 897func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 898 if alias(z, limit) { 899 z = nil // z is an alias for limit - cannot reuse 900 } 901 z = z.make(len(limit)) 902 903 bitLengthOfMSW := uint(n % _W) 904 if bitLengthOfMSW == 0 { 905 bitLengthOfMSW = _W 906 } 907 mask := Word((1 << bitLengthOfMSW) - 1) 908 909 for { 910 switch _W { 911 case 32: 912 for i := range z { 913 z[i] = Word(rand.Uint32()) 914 } 915 case 64: 916 for i := range z { 917 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 918 } 919 default: 920 panic("unknown word size") 921 } 922 z[len(limit)-1] &= mask 923 if z.cmp(limit) < 0 { 924 break 925 } 926 } 927 928 return z.norm() 929} 930 931// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 932// otherwise it sets z to x**y. The result is the value of z. 
933func (z nat) expNN(x, y, m nat, slow bool) nat { 934 if alias(z, x) || alias(z, y) { 935 // We cannot allow in-place modification of x or y. 936 z = nil 937 } 938 939 // x**y mod 1 == 0 940 if len(m) == 1 && m[0] == 1 { 941 return z.setWord(0) 942 } 943 // m == 0 || m > 1 944 945 // x**0 == 1 946 if len(y) == 0 { 947 return z.setWord(1) 948 } 949 // y > 0 950 951 // 0**y = 0 952 if len(x) == 0 { 953 return z.setWord(0) 954 } 955 // x > 0 956 957 // 1**y = 1 958 if len(x) == 1 && x[0] == 1 { 959 return z.setWord(1) 960 } 961 // x > 1 962 963 // x**1 == x 964 if len(y) == 1 && y[0] == 1 { 965 if len(m) != 0 { 966 return z.rem(x, m) 967 } 968 return z.set(x) 969 } 970 // y > 1 971 972 if len(m) != 0 { 973 // We likely end up being as long as the modulus. 974 z = z.make(len(m)) 975 976 // If the exponent is large, we use the Montgomery method for odd values, 977 // and a 4-bit, windowed exponentiation for powers of two, 978 // and a CRT-decomposed Montgomery method for the remaining values 979 // (even values times non-trivial odd values, which decompose into one 980 // instance of each of the first two cases). 981 if len(y) > 1 && !slow { 982 if m[0]&1 == 1 { 983 return z.expNNMontgomery(x, y, m) 984 } 985 if logM, ok := m.isPow2(); ok { 986 return z.expNNWindowed(x, y, logM) 987 } 988 return z.expNNMontgomeryEven(x, y, m) 989 } 990 } 991 992 z = z.set(x) 993 v := y[len(y)-1] // v > 0 because y is normalized and y > 0 994 shift := nlz(v) + 1 995 v <<= shift 996 var q nat 997 998 const mask = 1 << (_W - 1) 999 1000 // We walk through the bits of the exponent one by one. Each time we 1001 // see a bit, we square, thus doubling the power. If the bit is a one, 1002 // we also multiply by x, thus adding one to the power. 1003 1004 w := _W - int(shift) 1005 // zz and r are used to avoid allocating in mul and div as 1006 // otherwise the arguments would alias. 
1007 var zz, r nat 1008 for j := 0; j < w; j++ { 1009 zz = zz.sqr(z) 1010 zz, z = z, zz 1011 1012 if v&mask != 0 { 1013 zz = zz.mul(z, x) 1014 zz, z = z, zz 1015 } 1016 1017 if len(m) != 0 { 1018 zz, r = zz.div(r, z, m) 1019 zz, r, q, z = q, z, zz, r 1020 } 1021 1022 v <<= 1 1023 } 1024 1025 for i := len(y) - 2; i >= 0; i-- { 1026 v = y[i] 1027 1028 for j := 0; j < _W; j++ { 1029 zz = zz.sqr(z) 1030 zz, z = z, zz 1031 1032 if v&mask != 0 { 1033 zz = zz.mul(z, x) 1034 zz, z = z, zz 1035 } 1036 1037 if len(m) != 0 { 1038 zz, r = zz.div(r, z, m) 1039 zz, r, q, z = q, z, zz, r 1040 } 1041 1042 v <<= 1 1043 } 1044 } 1045 1046 return z.norm() 1047} 1048 1049// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd. 1050// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2 1051// and then uses the Chinese Remainder Theorem to combine the results. 1052// The recursive call using m1 will use expNNWindowed, 1053// while the recursive call using m2 will use expNNMontgomery. 1054// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”, 1055// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994. 1056// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf 1057func (z nat) expNNMontgomeryEven(x, y, m nat) nat { 1058 // Split m = m₁ × m₂ where m₁ = 2ⁿ 1059 n := m.trailingZeroBits() 1060 m1 := nat(nil).shl(natOne, n) 1061 m2 := nat(nil).shr(m, n) 1062 1063 // We want z = x**y mod m. 1064 // z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1 1065 // z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2 1066 // (We are using the math/big convention for names here, 1067 // where the computation is z = x**y mod m, so its parts are z1 and z2. 1068 // The paper is computing x = a**e mod n; it refers to these as x2 and z1.) 
1069 z1 := nat(nil).expNN(x, y, m1, false) 1070 z2 := nat(nil).expNN(x, y, m2, false) 1071 1072 // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper, 1073 // which uses only a single modInverse (and an easy one at that). 1074 // p = (z₁ - z₂) × m₂⁻¹ (mod m₁) 1075 // z = z₂ + p × m₂ 1076 // The final addition is in range because: 1077 // z = z₂ + p × m₂ 1078 // ≤ z₂ + (m₁-1) × m₂ 1079 // < m₂ + (m₁-1) × m₂ 1080 // = m₁ × m₂ 1081 // = m. 1082 z = z.set(z2) 1083 1084 // Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1. 1085 z1 = z1.subMod2N(z1, z2, n) 1086 1087 // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]). 1088 m2inv := nat(nil).modInverse(m2, m1) 1089 z2 = z2.mul(z1, m2inv) 1090 z2 = z2.trunc(z2, n) 1091 1092 // Reuse z1 for p * m2. 1093 z = z.add(z, z1.mul(z2, m2)) 1094 1095 return z 1096} 1097 1098// expNNWindowed calculates x**y mod m using a fixed, 4-bit window, 1099// where m = 2**logM. 1100func (z nat) expNNWindowed(x, y nat, logM uint) nat { 1101 if len(y) <= 1 { 1102 panic("big: misuse of expNNWindowed") 1103 } 1104 if x[0]&1 == 0 { 1105 // len(y) > 1, so y > logM. 1106 // x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM. 1107 return z.setWord(0) 1108 } 1109 if logM == 1 { 1110 return z.setWord(1) 1111 } 1112 1113 // zz is used to avoid allocating in mul as otherwise 1114 // the arguments would alias. 1115 w := int((logM + _W - 1) / _W) 1116 zzp := getNat(w) 1117 zz := *zzp 1118 1119 const n = 4 1120 // powers[i] contains x^i. 
1121 var powers [1 << n]*nat 1122 for i := range powers { 1123 powers[i] = getNat(w) 1124 } 1125 *powers[0] = powers[0].set(natOne) 1126 *powers[1] = powers[1].trunc(x, logM) 1127 for i := 2; i < 1<<n; i += 2 { 1128 p2, p, p1 := powers[i/2], powers[i], powers[i+1] 1129 *p = p.sqr(*p2) 1130 *p = p.trunc(*p, logM) 1131 *p1 = p1.mul(*p, x) 1132 *p1 = p1.trunc(*p1, logM) 1133 } 1134 1135 // Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1, 1136 // so we can compute x**(y mod 2**(logM-1)) instead of x**y. 1137 // That is, we can throw away all but the bottom logM-1 bits of y. 1138 // Instead of allocating a new y, we start reading y at the right word 1139 // and truncate it appropriately at the start of the loop. 1140 i := len(y) - 1 1141 mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word. 1142 mmask := ^Word(0) 1143 if mbits := (logM - 1) & (_W - 1); mbits != 0 { 1144 mmask = (1 << mbits) - 1 1145 } 1146 if i > mtop { 1147 i = mtop 1148 } 1149 advance := false 1150 z = z.setWord(1) 1151 for ; i >= 0; i-- { 1152 yi := y[i] 1153 if i == mtop { 1154 yi &= mmask 1155 } 1156 for j := 0; j < _W; j += n { 1157 if advance { 1158 // Account for use of 4 bits in previous iteration. 1159 // Unrolled loop for significant performance 1160 // gain. Use go test -bench=".*" in crypto/rsa 1161 // to check performance before making changes. 
1162 zz = zz.sqr(z) 1163 zz, z = z, zz 1164 z = z.trunc(z, logM) 1165 1166 zz = zz.sqr(z) 1167 zz, z = z, zz 1168 z = z.trunc(z, logM) 1169 1170 zz = zz.sqr(z) 1171 zz, z = z, zz 1172 z = z.trunc(z, logM) 1173 1174 zz = zz.sqr(z) 1175 zz, z = z, zz 1176 z = z.trunc(z, logM) 1177 } 1178 1179 zz = zz.mul(z, *powers[yi>>(_W-n)]) 1180 zz, z = z, zz 1181 z = z.trunc(z, logM) 1182 1183 yi <<= n 1184 advance = true 1185 } 1186 } 1187 1188 *zzp = zz 1189 putNat(zzp) 1190 for i := range powers { 1191 putNat(powers[i]) 1192 } 1193 1194 return z.norm() 1195} 1196 1197// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. 1198// Uses Montgomery representation. 1199func (z nat) expNNMontgomery(x, y, m nat) nat { 1200 numWords := len(m) 1201 1202 // We want the lengths of x and m to be equal. 1203 // It is OK if x >= m as long as len(x) == len(m). 1204 if len(x) > numWords { 1205 _, x = nat(nil).div(nil, x, m) 1206 // Note: now len(x) <= numWords, not guaranteed ==. 1207 } 1208 if len(x) < numWords { 1209 rr := make(nat, numWords) 1210 copy(rr, x) 1211 x = rr 1212 } 1213 1214 // Ideally the precomputations would be performed outside, and reused 1215 // k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson 1216 // Iteration for Multiplicative Inverses Modulo Prime Powers". 
1217 k0 := 2 - m[0] 1218 t := m[0] - 1 1219 for i := 1; i < _W; i <<= 1 { 1220 t *= t 1221 k0 *= (t + 1) 1222 } 1223 k0 = -k0 1224 1225 // RR = 2**(2*_W*len(m)) mod m 1226 RR := nat(nil).setWord(1) 1227 zz := nat(nil).shl(RR, uint(2*numWords*_W)) 1228 _, RR = nat(nil).div(RR, zz, m) 1229 if len(RR) < numWords { 1230 zz = zz.make(numWords) 1231 copy(zz, RR) 1232 RR = zz 1233 } 1234 // one = 1, with equal length to that of m 1235 one := make(nat, numWords) 1236 one[0] = 1 1237 1238 const n = 4 1239 // powers[i] contains x^i 1240 var powers [1 << n]nat 1241 powers[0] = powers[0].montgomery(one, RR, m, k0, numWords) 1242 powers[1] = powers[1].montgomery(x, RR, m, k0, numWords) 1243 for i := 2; i < 1<<n; i++ { 1244 powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords) 1245 } 1246 1247 // initialize z = 1 (Montgomery 1) 1248 z = z.make(numWords) 1249 copy(z, powers[0]) 1250 1251 zz = zz.make(numWords) 1252 1253 // same windowed exponent, but with Montgomery multiplications 1254 for i := len(y) - 1; i >= 0; i-- { 1255 yi := y[i] 1256 for j := 0; j < _W; j += n { 1257 if i != len(y)-1 || j != 0 { 1258 zz = zz.montgomery(z, z, m, k0, numWords) 1259 z = z.montgomery(zz, zz, m, k0, numWords) 1260 zz = zz.montgomery(z, z, m, k0, numWords) 1261 z = z.montgomery(zz, zz, m, k0, numWords) 1262 } 1263 zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords) 1264 z, zz = zz, z 1265 yi <<= n 1266 } 1267 } 1268 // convert to regular number 1269 zz = zz.montgomery(z, one, m, k0, numWords) 1270 1271 // One last reduction, just in case. 1272 // See golang.org/issue/13907. 1273 if zz.cmp(m) >= 0 { 1274 // Common case is m has high bit set; in that case, 1275 // since zz is the same length as m, there can be just 1276 // one multiple of m to remove. Just subtract. 1277 // We think that the subtract should be sufficient in general, 1278 // so do that unconditionally, but double-check, 1279 // in case our beliefs are wrong. 1280 // The div is not expected to be reached. 
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}

// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
//
// If z is zero (or every written byte is zero), the returned i equals
// len(buf), i.e. buf[i:] is empty.
func (z nat) bytes(buf []byte) (i int) {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.

	// Fill buf from its last byte toward its first: words of z are visited
	// least-significant first, and each word contributes _S bytes, low
	// byte first.
	i = len(buf)
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			if i >= 0 {
				buf[i] = byte(d)
			} else if byte(d) != 0 {
				// A byte that would land before buf[0] is tolerated only
				// when it is zero (padding in the top word of z).
				panic("math/big: buffer too small to fit value")
			}
			d >>= 8
		}
	}

	// i went negative if z supplied more (zero) bytes than buf can hold;
	// clamp it back to the start of buf.
	if i < 0 {
		i = 0
	}
	// Advance past leading zero bytes so that buf[i:] starts at the first
	// nonzero byte.
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}

// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
// It reads a 64-bit value when _W is 64 and a 32-bit value otherwise;
// setBytes calls it with a slice of exactly _S bytes.
func bigEndianWord(buf []byte) Word {
	if _W == 64 {
		return Word(byteorder.BeUint64(buf))
	}
	return Word(byteorder.BeUint32(buf))
}

// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
1331func (z nat) setBytes(buf []byte) nat { 1332 z = z.make((len(buf) + _S - 1) / _S) 1333 1334 i := len(buf) 1335 for k := 0; i >= _S; k++ { 1336 z[k] = bigEndianWord(buf[i-_S : i]) 1337 i -= _S 1338 } 1339 if i > 0 { 1340 var d Word 1341 for s := uint(0); i > 0; s += 8 { 1342 d |= Word(buf[i-1]) << s 1343 i-- 1344 } 1345 z[len(z)-1] = d 1346 } 1347 1348 return z.norm() 1349} 1350 1351// sqrt sets z = ⌊√x⌋ 1352func (z nat) sqrt(x nat) nat { 1353 if x.cmp(natOne) <= 0 { 1354 return z.set(x) 1355 } 1356 if alias(z, x) { 1357 z = nil 1358 } 1359 1360 // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. 1361 // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). 1362 // https://members.loria.fr/PZimmermann/mca/pub226.html 1363 // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; 1364 // otherwise it converges to the correct z and stays there. 1365 var z1, z2 nat 1366 z1 = z 1367 z1 = z1.setUint64(1) 1368 z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x 1369 for n := 0; ; n++ { 1370 z2, _ = z2.div(nil, x, z1) 1371 z2 = z2.add(z2, z1) 1372 z2 = z2.shr(z2, 1) 1373 if z2.cmp(z1) >= 0 { 1374 // z1 is answer. 1375 // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. 1376 if n&1 == 0 { 1377 return z1 1378 } 1379 return z.set(z1) 1380 } 1381 z1, z2 = z2, z1 1382 } 1383} 1384 1385// subMod2N returns z = (x - y) mod 2ⁿ. 
1386func (z nat) subMod2N(x, y nat, n uint) nat { 1387 if uint(x.bitLen()) > n { 1388 if alias(z, x) { 1389 // ok to overwrite x in place 1390 x = x.trunc(x, n) 1391 } else { 1392 x = nat(nil).trunc(x, n) 1393 } 1394 } 1395 if uint(y.bitLen()) > n { 1396 if alias(z, y) { 1397 // ok to overwrite y in place 1398 y = y.trunc(y, n) 1399 } else { 1400 y = nat(nil).trunc(y, n) 1401 } 1402 } 1403 if x.cmp(y) >= 0 { 1404 return z.sub(x, y) 1405 } 1406 // x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x). 1407 z = z.sub(y, x) 1408 for uint(len(z))*_W < n { 1409 z = append(z, 0) 1410 } 1411 for i := range z { 1412 z[i] = ^z[i] 1413 } 1414 z = z.trunc(z, n) 1415 return z.add(z, natOne) 1416} 1417