// Copyright 2015-2023 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

//! Multi-precision integers.
//!
//! # Modular Arithmetic.
//!
//! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
//! modulus *m*. We work in finite commutative rings instead of finite fields
//! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
//! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
//! finite field.
//!
//! In some calculations we need to deal with multiple rings at once. For
//! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
//! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
//! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
//! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
//! the "unit" pattern described in [Static checking of units in Servo].
//!
//! `Elem` also uses the static unit checking pattern to statically track the
//! Montgomery factors that need to be canceled out in each value using its
//! `E` parameter.
//!
//! [Static checking of units in Servo]:
//! https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
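//!
//! As a minimal, self-contained sketch of the pattern (hypothetical stand-in
//! types, not this module's API):
//!
//! ```
//! use core::marker::PhantomData;
//!
//! enum P {} // marker type for ℤ/pℤ
//! enum Q {} // marker type for ℤ/qℤ
//!
//! struct Elem<M> {
//!     value: u64, // stand-in for the limbs
//!     m: PhantomData<M>,
//! }
//!
//! // Both operands must come from the same ring ℤ/mℤ.
//! fn mul<M>(a: &Elem<M>, b: &Elem<M>) -> Elem<M> {
//!     Elem { value: a.value * b.value, m: PhantomData } // (reduction elided)
//! }
//! ```
//!
//! With these definitions, multiplying an `Elem<P>` by an `Elem<Q>` fails to
//! compile, because they are distinct types even though they share the same
//! representation.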

use self::boxed_limbs::BoxedLimbs;
pub(crate) use self::{
    modulus::{Modulus, PartialModulus, MODULUS_MAX_LIMBS},
    private_exponent::PrivateExponent,
};
use super::n0::N0;
pub(crate) use super::nonnegative::Nonnegative;
use crate::{
    arithmetic::montgomery::*,
    bits, c, cpu, error,
    limb::{self, Limb, LimbMask, LIMB_BITS},
    polyfill::u64_from_usize,
};
use alloc::vec;
use core::{marker::PhantomData, num::NonZeroU64};

mod boxed_limbs;
mod modulus;
mod private_exponent;

/// A prime modulus.
///
/// # Safety
///
/// Some logic may assume a `Prime` number is non-zero, and thus a non-empty
/// array of limbs, or make similar assumptions. TODO: Any such logic should
/// be encapsulated here, or this trait should be made non-`unsafe`. TODO:
/// non-zero-ness and non-empty-ness should be factored out into a separate
/// trait. (In retrospect, this shouldn't have been made an `unsafe` trait
/// preemptively.)
pub unsafe trait Prime {}

struct Width<M> {
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}

/// A modulus *s* that is smaller than another modulus *l* so every element of
/// ℤ/sℤ is also an element of ℤ/lℤ.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value, e.g. by assuming the larger modulus has at least as many limbs.
/// TODO: Any such logic should be encapsulated here, or this trait should be
/// made non-`unsafe`. (In retrospect, this shouldn't have been made an `unsafe`
/// trait preemptively.)
pub unsafe trait SmallerModulus<L> {}

/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduced_once()`.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value, e.g. by assuming that the smaller modulus is at most one limb
/// smaller than the larger modulus. TODO: Any such logic should be
/// encapsulated here, or this trait should be made non-`unsafe`. (In retrospect,
/// this shouldn't have been made an `unsafe` trait preemptively.)
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}

/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value. TODO: Any such logic should be encapsulated here, or this trait
/// should be made non-`unsafe`. (In retrospect, this shouldn't have been made
/// an `unsafe` trait preemptively.)
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}

pub trait PublicModulus {}

/// Elements of ℤ/mℤ for some modulus *m*.
//
// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
// submodule. However, for maximum clarity, we always explicitly use
// `Unencoded` within the `bigint` submodule.
pub struct Elem<M, E = Unencoded> {
    limbs: BoxedLimbs<M>,

    /// The number of Montgomery factors that need to be canceled out from
    /// `value` to get the actual value.
    encoding: PhantomData<E>,
}
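
// The `E` parameter tracks Montgomery factors using the encoding types from
// `arithmetic::montgomery`: `Unencoded` (no factor), `R` (one factor of R),
// `RR` (two factors), and `RInverse` (a factor of R**-1). `ProductEncoding`
// gives the encoding of a Montgomery product; e.g. multiplying an `R` value
// by an `Unencoded` value yields an `Unencoded` result, since the extra R
// cancels the R**-1 introduced by the reduction.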

// TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
// is resolved or restrict `M: Clone` and `E: Clone`.
impl<M, E> Clone for Elem<M, E> {
    fn clone(&self) -> Self {
        Self {
            limbs: self.limbs.clone(),
            encoding: self.encoding,
        }
    }
}

impl<M, E> Elem<M, E> {
    #[inline]
    pub fn is_zero(&self) -> bool {
        self.limbs.is_zero()
    }
}

/// Does a Montgomery reduction on `limbs` assuming they are Montgomery-encoded
/// ('R') and assuming they are the same size as `m`, but perhaps not reduced
/// mod `m`. The result will be fully reduced mod `m`.
fn from_montgomery_amm<M>(limbs: BoxedLimbs<M>, m: &Modulus<M>) -> Elem<M, Unencoded> {
    debug_assert_eq!(limbs.len(), m.limbs().len());

    let mut limbs = limbs;
    let num_limbs = m.width().num_limbs;
    let mut one = [0; MODULUS_MAX_LIMBS];
    one[0] = 1;
    let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
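    // Montgomery multiplication computes a * b * R**-1 (mod m), so
    // multiplying by 1 yields limbs * R**-1 (mod m): exactly one Montgomery
    // factor is removed, and the result is fully reduced mod `m`.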
    limbs_mont_mul(&mut limbs, one, m.limbs(), m.n0(), m.cpu_features());
    Elem {
        limbs,
        encoding: PhantomData,
    }
}

impl<M> Elem<M, R> {
    #[inline]
    pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
        from_montgomery_amm(self.limbs, m)
    }
}

impl<M> Elem<M, Unencoded> {
    pub fn from_be_bytes_padded(
        input: untrusted::Input,
        m: &Modulus<M>,
    ) -> Result<Self, error::Unspecified> {
        Ok(Self {
            limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
            encoding: PhantomData,
        })
    }

    #[inline]
    pub fn fill_be_bytes(&self, out: &mut [u8]) {
        // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
        limb::big_endian_from_limbs(&self.limbs, out)
    }

    fn is_one(&self) -> bool {
        limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
    }
}

pub fn elem_mul<M, AF, BF>(
    a: &Elem<M, AF>,
    b: Elem<M, BF>,
    m: &Modulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    elem_mul_(a, b, &m.as_partial())
}

fn elem_mul_<M, AF, BF>(
    a: &Elem<M, AF>,
    mut b: Elem<M, BF>,
    m: &PartialModulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    limbs_mont_mul(&mut b.limbs, &a.limbs, m.limbs(), m.n0(), m.cpu_features());
    Elem {
        limbs: b.limbs,
        encoding: PhantomData,
    }
}

fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
    prefixed_extern! {
        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
    }
    unsafe {
        LIMBS_shl_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            m.limbs().as_ptr(),
            m.limbs().len(),
        );
    }
}

pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, Unencoded> {
    let mut r = a.limbs.clone();
    assert!(r.len() <= m.limbs().len());
    limb::limbs_reduce_once_constant_time(&mut r, m.limbs());
    Elem {
        limbs: BoxedLimbs::new_unchecked(r.into_limbs()),
        encoding: PhantomData,
    }
}

#[inline]
pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, RInverse> {
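    // Montgomery reduction from ℤ/lℤ divides by R, so the result carries an
    // `RInverse` factor that the caller must cancel, e.g. by multiplying by
    // `oneRR()` to recover the unencoded value (see `test_elem_reduced`).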
    let mut tmp = [0; MODULUS_MAX_LIMBS];
    let tmp = &mut tmp[..a.limbs.len()];
    tmp.copy_from_slice(&a.limbs);

    let mut r = m.zero();
    limbs_from_mont_in_place(&mut r.limbs, tmp, m.limbs(), m.n0());
    r
}

fn elem_squared<M, E>(
    mut a: Elem<M, E>,
    m: &PartialModulus<M>,
) -> Elem<M, <(E, E) as ProductEncoding>::Output>
where
    (E, E): ProductEncoding,
{
    limbs_mont_square(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features());
    Elem {
        limbs: a.limbs,
        encoding: PhantomData,
    }
}

pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
    a: Elem<Smaller, Unencoded>,
    m: &Modulus<Larger>,
) -> Elem<Larger, Unencoded> {
    let mut r = m.zero();
    r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
    r
}

// TODO: Document why this works for all Montgomery factors.
pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, m.limbs());
    a
}

// TODO: Document why this works for all Montgomery factors.
pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    prefixed_extern! {
        // `r` and `a` may alias.
        fn LIMBS_sub_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    unsafe {
        LIMBS_sub_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs().as_ptr(),
            m.limbs().len(),
        );
    }
    a
}

// The value 1, Montgomery-encoded some number of times.
pub struct One<M, E>(Elem<M, E>);

impl<M> One<M, RR> {
    // Returns RR = R**2 (mod m) where R = 2**r is the smallest power of
    // 2**LIMB_BITS such that R > m.
    //
    // Even though the assembly on some 32-bit platforms works with 64-bit
    // values, using `LIMB_BITS` here, rather than `N0::LIMBS_USED * LIMB_BITS`,
    // is correct because R**2 will still be a multiple of the latter as
    // `N0::LIMBS_USED` is either one or two.
    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
        let m_bits = m_bits.as_usize_bits();
        let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

        // base = 2**(lg m - 1).
        let bit = m_bits - 1;
        let mut base = m.zero();
        base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);

        // Double `base` so that base == R == 2**r (mod m). For normal moduli
        // that have the high bit of the highest limb set, this requires one
        // doubling. Unusual moduli require more doublings but we are less
        // concerned about the performance of those.
        //
        // Then double `base` again so that base == 2*R (mod m), i.e. `2` in
        // Montgomery form (`elem_exp_vartime()` requires the base to be in
        // Montgomery form). Then compute
        // RR = R**2 == 2**r * R == base**r (mod m), where the Montgomery
        // exponentiation cancels the extra factors of R as it goes.
        //
        // Take advantage of the fact that `elem_mul_by_2` is faster than
        // `elem_squared` by replacing some of the early squarings with shifts.
        // TODO: Benchmark shift vs. squaring performance to determine the
        // optimal value of `LG_BASE`.
        const LG_BASE: usize = 2; // Shifts vs. squaring trade-off.
        debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
        let shifts = r - bit + LG_BASE;
        // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
        // since we require the modulus to have at least `MODULUS_MIN_LIMBS`
        // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
        // `r / LG_BASE` is non-zero.
        //
        // The maximum value of `r` is determined by
        // `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of
        // `LIMB_BITS` so the maximum Hamming weight is bounded by
        // `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit
        // moduli the Hamming weight is 1. For the other common case of 3072
        // the Hamming weight is 2.
        let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
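        // Worked example (illustrative): for a 2048-bit modulus with 64-bit
        // limbs, r = 2048, bit = 2047, shifts = 2048 - 2047 + 2 = 3, and
        // exponent = 2048 / 2 = 1024. After three doublings,
        // base == 2**2050 == 4 * R (mod m), i.e. `4` in Montgomery form, and
        // raising that to the 1024th power with Montgomery multiplications
        // yields 4**1024 * R == 2**2048 * R == R * R == RR (mod m).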
        for _ in 0..shifts {
            elem_mul_by_2(&mut base, m)
        }
        let RR = elem_exp_vartime(base, exponent, m);

        Self(Elem {
            limbs: RR.limbs,
            encoding: PhantomData, // PhantomData<RR>
        })
    }
}

impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
    fn as_ref(&self) -> &Elem<M, E> {
        &self.0
    }
}

impl<M: PublicModulus, E> Clone for One<M, E> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

/// Calculates base**exponent (mod m).
///
/// The run time is a function of the number of limbs in `m` and the bit
/// length and Hamming weight of `exponent`. The bounds on `m` are pretty
/// obvious but the bounds on `exponent` are less obvious. Callers should
/// document the bounds they place on the maximum value and maximum Hamming
/// weight of `exponent`.
// TODO: The test coverage needs to be expanded, e.g. test with the largest
// accepted exponent and with the most common values of 65537 and 3.
pub(crate) fn elem_exp_vartime<M>(
    base: Elem<M, R>,
    exponent: NonZeroU64,
    m: &PartialModulus<M>,
) -> Elem<M, R> {
    // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
    // square-and-multiply that scans the exponent from the most significant
    // bit to the least significant bit (left-to-right). Left-to-right requires
    // less storage compared to right-to-left scanning, at the cost of needing
    // to compute `exponent.leading_zeros()`, which we assume to be cheap.
    //
    // As explained in [Knuth], exponentiation by squaring is the most
    // efficient algorithm when the Hamming weight is 2 or less. It isn't the
    // most efficient for all other, uncommon, exponent values but any
    // suboptimality is bounded at least by the small bit length of `exponent`
    // as enforced by its type.
    //
    // This implementation is slightly simplified by taking advantage of the
    // fact that we require the exponent to be a positive integer.
    //
    // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
    // Algorithms (3rd Edition), Section 4.6.3.
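    //
    // For example, with exponent 65537 == 0b1_0000_0000_0000_0001, the loop
    // below performs 16 squarings and one multiplication by `base`.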
    let exponent = exponent.get();
    let mut acc = base.clone();
    let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
    debug_assert!((exponent & bit) != 0);
    while bit > 1 {
        bit >>= 1;
        acc = elem_squared(acc, m);
        if (exponent & bit) != 0 {
            acc = elem_mul_(&base, acc, m);
        }
    }
    acc
}

/// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
pub fn elem_inverse_consttime<M: Prime>(
    a: Elem<M, R>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    elem_exp_consttime(a, &PrivateExponent::for_flt(m), m)
}

#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::{bssl, limb::Window};

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs().len();

    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        prefixed_extern! {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
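    // Fill in the rest of the table so that table[i] == base**i (Montgomery
    // encoded): each even entry is the square of entry i/2, and each odd
    // entry is entry i-1 times the base.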
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, m.limbs(), m.n0(), m.cpu_features());
    }

    let (r, _) = limb::fold_5_bit_windows(
        exponent.limbs(),
        |initial_window| {
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    let r = r.into_unencoded(m);

    Ok(r)
}

#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::LIMB_BYTES;

    // Pretty much all the math here requires CPU feature detection to have
    // been done. `cpu_features` isn't threaded through all the internal
    // functions, so just make it clear that it has been done at this point.
    let cpu_features = m.cpu_features();

    // The x86_64 assembly was written under the assumption that the input data
    // is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
    // Similarly, OpenSSL uses the x86_64 assembly functions by giving it only
    // inputs `tmp`, `am`, and `np` that immediately follow the table. All the
    // awkwardness here stems from trying to use the assembly code like OpenSSL
    // does.

    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs().len();

    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    const ACC: usize = 0; // `tmp` in OpenSSL
    const BASE: usize = ACC + 1; // `am` in OpenSSL
    const M: usize = BASE + 1; // `np` in OpenSSL
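
    // `state` holds the three working values contiguously, in the layout the
    // assembly expects: the accumulator at index `ACC`, the Montgomery-encoded
    // base at `BASE`, and a copy of the modulus limbs at `M`.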

    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(m.limbs());

    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    fn gather_square(
        table: &[Limb],
        state: &mut [Limb],
        n0: &N0,
        i: Window,
        num_limbs: usize,
        cpu_features: cpu::Features,
    ) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0, cpu_features);
    }

    fn gather_mul_base_amm(
        table: &[Limb],
        state: &mut [Limb],
        n0: &N0,
        i: Window,
        num_limbs: usize,
    ) {
        prefixed_extern! {
            fn bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    fn power_amm(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0.
    {
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR().0.limbs, m.limbs(), m.n0(), cpu_features);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

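    // As in the generic path, fill table[i] = base**i: even entries square
    // entry i/2 and odd entries multiply entry i-1 by the base, with the
    // working value kept in `state` between the gather and scatter calls.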
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            // TODO: Optimize this to avoid gathering
            gather_square(table, state, m.n0(), i / 2, num_limbs, cpu_features);
        } else {
            gather_mul_base_amm(table, state, m.n0(), i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    let state = limb::fold_5_bit_windows(
        exponent.limbs(),
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power_amm(table, state, m.n0(), window, num_limbs);
            state
        },
    );

    let mut r_amm = base.limbs;
    r_amm.copy_from_slice(entry(state, ACC, num_limbs));

    Ok(from_montgomery_amm(r_amm, m))
}

/// Verifies a == b**-1 (mod m), i.e. a**-1 == b (mod m).
pub fn verify_inverses_consttime<M>(
    a: &Elem<M, R>,
    b: Elem<M, Unencoded>,
    m: &Modulus<M>,
) -> Result<(), error::Unspecified> {
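    // `a` carries one Montgomery factor (`R`) and `b` carries none, so their
    // `ProductEncoding` is `Unencoded`; if a == b**-1 (mod m), that unencoded
    // product is exactly 1.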
    if elem_mul(a, b, m).is_one() {
        Ok(())
    } else {
        Err(error::Unspecified)
    }
}

#[inline]
pub fn elem_verify_equal_consttime<M, E>(
    a: &Elem<M, E>,
    b: &Elem<M, E>,
) -> Result<(), error::Unspecified> {
    if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
        Ok(())
    } else {
        Err(error::Unspecified)
    }
}

// TODO: Move these methods from `Nonnegative` to `Modulus`.
impl Nonnegative {
    pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
        self.verify_less_than_modulus(m)?;
        let mut r = m.zero();
        r.limbs[0..self.limbs().len()].copy_from_slice(self.limbs());
        Ok(r)
    }

    pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
        if self.limbs().len() > m.limbs().len() {
            return Err(error::Unspecified);
        }
        if self.limbs().len() == m.limbs().len() {
            if limb::limbs_less_than_limbs_consttime(self.limbs(), m.limbs()) != LimbMask::True {
                return Err(error::Unspecified);
            }
        }
        Ok(())
    }
}

/// r *= a
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            a.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

/// r = a * b
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(
    r: &mut [Limb],
    a: &[Limb],
    b: &[Limb],
    m: &[Limb],
    n0: &N0,
    _cpu_features: cpu::Features,
) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

/// r = r**2
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
    debug_assert_eq!(r.len(), m.len());
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            r.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

prefixed_extern! {
    // `r` and/or `a` and/or `b` may alias.
    fn bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );
}

#[cfg(test)]
mod tests {
    use super::{modulus::MODULUS_MIN_LIMBS, *};
    use crate::{limb::LIMB_BYTES, test};
    use alloc::format;

    // Type-level representation of an arbitrary modulus.
    struct M {}

    impl PublicModulus for M {}

    #[test]
    fn test_elem_exp_consttime() {
        let cpu_features = cpu::features();
        test::run(
            test_file!("../../crypto/fipsmodule/bn/test/mod_exp_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M", cpu_features);
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_for_test_only(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.
    #[test]
    fn test_elem_mul() {
        let cpu_features = cpu::features();
        test::run(
            test_file!("../../crypto/fipsmodule/bn/test/mod_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M", cpu_features);
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_squared() {
        let cpu_features = cpu::features();
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M", cpu_features);
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);

                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced() {
        let cpu_features = cpu::features();
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M", cpu_features);
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

                let actual_result = elem_reduced(&a, &m);
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced_once() {
        let cpu_features = cpu::features();
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ", cpu_features);
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N", cpu_features);
                let a = consume_elem::<N>(test_case, "A", &n);

                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_modulus_debug() {
        let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(
            untrusted::Input::from(&[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS]),
            cpu::features(),
        )
        .unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs().len()].copy_from_slice(value.limbs());
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    fn consume_modulus<M>(
        test_case: &mut test::TestCase,
        name: &str,
        cpu_features: cpu::Features,
    ) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value), cpu_features)
                .unwrap();
        value
    }

    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        if elem_verify_equal_consttime(a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }
}