1 // Copyright 2015-2023 Brian Smith.
2 //
3 // Permission to use, copy, modify, and/or distribute this software for any
4 // purpose with or without fee is hereby granted, provided that the above
5 // copyright notice and this permission notice appear in all copies.
6 //
7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 
15 //! Multi-precision integers.
16 //!
17 //! # Modular Arithmetic.
18 //!
19 //! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
20 //! modulus *m*. We work in finite commutative rings instead of finite fields
21 //! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
22 //! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
23 //! finite field.
24 //!
25 //! In some calculations we need to deal with multiple rings at once. For
26 //! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
27 //! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
28 //! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
29 //! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
30 //! the "unit" pattern described in [Static checking of units in Servo].
31 //!
32 //! `Elem` also uses the static unit checking pattern to statically track the
//! Montgomery factors that need to be canceled out in each value using its
34 //! `E` parameter.
35 //!
36 //! [Static checking of units in Servo]:
37 //!     https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
38 
39 use self::boxed_limbs::BoxedLimbs;
40 pub(crate) use self::{
41     modulus::{Modulus, PartialModulus, MODULUS_MAX_LIMBS},
42     private_exponent::PrivateExponent,
43 };
44 use super::n0::N0;
45 pub(crate) use super::nonnegative::Nonnegative;
46 use crate::{
47     arithmetic::montgomery::*,
48     bits, c, cpu, error,
49     limb::{self, Limb, LimbMask, LIMB_BITS},
50     polyfill::u64_from_usize,
51 };
52 use alloc::vec;
53 use core::{marker::PhantomData, num::NonZeroU64};
54 
55 mod boxed_limbs;
56 mod modulus;
57 mod private_exponent;
58 
/// A prime modulus.
///
/// This is a marker trait: it declares no items and only tags a modulus type
/// so that functions such as `elem_inverse_consttime()` can require
/// `M: Prime`.
///
/// # Safety
///
/// Some logic may assume a `Prime` number is non-zero, and thus a non-empty
/// array of limbs, or make similar assumptions. TODO: Any such logic should
/// be encapsulated here, or this trait should be made non-`unsafe`. TODO:
/// non-zero-ness and non-empty-ness should be factored out into a separate
/// trait. (In retrospect, this shouldn't have been made an `unsafe` trait
/// preemptively.)
pub unsafe trait Prime {}
70 
/// A limb count tagged with the modulus type `M` it was derived from.
struct Width<M> {
    /// The number of limbs.
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}
77 
/// A modulus *s* that is smaller than another modulus *l* so every element of
/// ℤ/sℤ is also an element of ℤ/lℤ.
///
/// This is a marker trait with no items; the implementing type is the smaller
/// modulus and the type parameter `L` is the larger one.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value, e.g. by assuming the larger modulus has at least as many limbs.
/// TODO: Any such logic should be encapsulated here, or this trait should be
/// made non-`unsafe`. (In retrospect, this shouldn't have been made an `unsafe`
/// trait preemptively.)
pub unsafe trait SmallerModulus<L> {}
89 
/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduced_once()`.
///
/// This is a marker trait with no items; it refines `SmallerModulus<L>`.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value, e.g. by assuming that the smaller modulus is at most one limb
/// smaller than the larger modulus. TODO: Any such logic should be
/// encapsulated here, or this trait should be made non-`unsafe`. (In retrospect,
/// this shouldn't have been made an `unsafe` trait preemptively.)
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}
102 
/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ, `elem_reduced()`.
///
/// This is a marker trait with no items; it refines `SmallerModulus<L>`.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value. TODO: Any such logic should be encapsulated here, or this trait
/// should be made non-`unsafe`. (In retrospect, this shouldn't have been made
/// an `unsafe` trait preemptively.)
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}
114 
/// Marker trait (no items) for modulus types.
///
/// Within this file it is the bound that enables `Clone` for `One<M, E>`.
/// NOTE(review): presumably it marks moduli whose value is public rather
/// than secret; confirm against the implementors.
pub trait PublicModulus {}
116 
/// Elements of ℤ/mℤ for some modulus *m*.
///
/// `M` is the type-level tag of the modulus and `E` is the type-level tag of
/// the value's Montgomery encoding; neither is stored at run time.
//
// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
// submodule. However, for maximum clarity, we always explicitly use
// `Unencoded` within the `bigint` submodule.
pub struct Elem<M, E = Unencoded> {
    /// The limbs holding the (possibly Montgomery-encoded) value.
    limbs: BoxedLimbs<M>,

    /// The number of Montgomery factors that need to be canceled out from
    /// `limbs` to get the actual value.
    encoding: PhantomData<E>,
}
129 
130 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
131 // is resolved or restrict `M: Clone` and `E: Clone`.
132 impl<M, E> Clone for Elem<M, E> {
clone(&self) -> Self133     fn clone(&self) -> Self {
134         Self {
135             limbs: self.limbs.clone(),
136             encoding: self.encoding,
137         }
138     }
139 }
140 
141 impl<M, E> Elem<M, E> {
142     #[inline]
is_zero(&self) -> bool143     pub fn is_zero(&self) -> bool {
144         self.limbs.is_zero()
145     }
146 }
147 
148 /// Does a Montgomery reduction on `limbs` assuming they are Montgomery-encoded ('R') and assuming
149 /// they are the same size as `m`, but perhaps not reduced mod `m`. The result will be
150 /// fully reduced mod `m`.
from_montgomery_amm<M>(limbs: BoxedLimbs<M>, m: &Modulus<M>) -> Elem<M, Unencoded>151 fn from_montgomery_amm<M>(limbs: BoxedLimbs<M>, m: &Modulus<M>) -> Elem<M, Unencoded> {
152     debug_assert_eq!(limbs.len(), m.limbs().len());
153 
154     let mut limbs = limbs;
155     let num_limbs = m.width().num_limbs;
156     let mut one = [0; MODULUS_MAX_LIMBS];
157     one[0] = 1;
158     let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
159     limbs_mont_mul(&mut limbs, one, m.limbs(), m.n0(), m.cpu_features());
160     Elem {
161         limbs,
162         encoding: PhantomData,
163     }
164 }
165 
166 impl<M> Elem<M, R> {
167     #[inline]
into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded>168     pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
169         from_montgomery_amm(self.limbs, m)
170     }
171 }
172 
173 impl<M> Elem<M, Unencoded> {
from_be_bytes_padded( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>174     pub fn from_be_bytes_padded(
175         input: untrusted::Input,
176         m: &Modulus<M>,
177     ) -> Result<Self, error::Unspecified> {
178         Ok(Self {
179             limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
180             encoding: PhantomData,
181         })
182     }
183 
184     #[inline]
fill_be_bytes(&self, out: &mut [u8])185     pub fn fill_be_bytes(&self, out: &mut [u8]) {
186         // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
187         limb::big_endian_from_limbs(&self.limbs, out)
188     }
189 
is_one(&self) -> bool190     fn is_one(&self) -> bool {
191         limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
192     }
193 }
194 
elem_mul<M, AF, BF>( a: &Elem<M, AF>, b: Elem<M, BF>, m: &Modulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,195 pub fn elem_mul<M, AF, BF>(
196     a: &Elem<M, AF>,
197     b: Elem<M, BF>,
198     m: &Modulus<M>,
199 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
200 where
201     (AF, BF): ProductEncoding,
202 {
203     elem_mul_(a, b, &m.as_partial())
204 }
205 
elem_mul_<M, AF, BF>( a: &Elem<M, AF>, mut b: Elem<M, BF>, m: &PartialModulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,206 fn elem_mul_<M, AF, BF>(
207     a: &Elem<M, AF>,
208     mut b: Elem<M, BF>,
209     m: &PartialModulus<M>,
210 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
211 where
212     (AF, BF): ProductEncoding,
213 {
214     limbs_mont_mul(&mut b.limbs, &a.limbs, m.limbs(), m.n0(), m.cpu_features());
215     Elem {
216         limbs: b.limbs,
217         encoding: PhantomData,
218     }
219 }
220 
/// Doubles `a` in place modulo `m` by calling the external `LIMBS_shl_mod`
/// routine with `a` as both the destination and the source operand.
///
/// The encoding `AF` is left unchanged.
fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
    prefixed_extern! {
        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
    }
    // NOTE(review): only the limb count of `m` is passed, so this relies on
    // `a.limbs` having the same number of limbs as `m`, and on
    // `LIMBS_shl_mod` permitting `r` and `a` to alias; confirm both.
    unsafe {
        LIMBS_shl_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            m.limbs().as_ptr(),
            m.limbs().len(),
        );
    }
}
234 
elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, Unencoded>235 pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
236     a: &Elem<Larger, Unencoded>,
237     m: &Modulus<Smaller>,
238 ) -> Elem<Smaller, Unencoded> {
239     let mut r = a.limbs.clone();
240     assert!(r.len() <= m.limbs().len());
241     limb::limbs_reduce_once_constant_time(&mut r, m.limbs());
242     Elem {
243         limbs: BoxedLimbs::new_unchecked(r.into_limbs()),
244         encoding: PhantomData,
245     }
246 }
247 
248 #[inline]
elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, RInverse>249 pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
250     a: &Elem<Larger, Unencoded>,
251     m: &Modulus<Smaller>,
252 ) -> Elem<Smaller, RInverse> {
253     let mut tmp = [0; MODULUS_MAX_LIMBS];
254     let tmp = &mut tmp[..a.limbs.len()];
255     tmp.copy_from_slice(&a.limbs);
256 
257     let mut r = m.zero();
258     limbs_from_mont_in_place(&mut r.limbs, tmp, m.limbs(), m.n0());
259     r
260 }
261 
elem_squared<M, E>( mut a: Elem<M, E>, m: &PartialModulus<M>, ) -> Elem<M, <(E, E) as ProductEncoding>::Output> where (E, E): ProductEncoding,262 fn elem_squared<M, E>(
263     mut a: Elem<M, E>,
264     m: &PartialModulus<M>,
265 ) -> Elem<M, <(E, E) as ProductEncoding>::Output>
266 where
267     (E, E): ProductEncoding,
268 {
269     limbs_mont_square(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features());
270     Elem {
271         limbs: a.limbs,
272         encoding: PhantomData,
273     }
274 }
275 
elem_widen<Larger, Smaller: SmallerModulus<Larger>>( a: Elem<Smaller, Unencoded>, m: &Modulus<Larger>, ) -> Elem<Larger, Unencoded>276 pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
277     a: Elem<Smaller, Unencoded>,
278     m: &Modulus<Larger>,
279 ) -> Elem<Larger, Unencoded> {
280     let mut r = m.zero();
281     r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
282     r
283 }
284 
285 // TODO: Document why this works for all Montgomery factors.
elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>286 pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
287     limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, m.limbs());
288     a
289 }
290 
// TODO: Document why this works for all Montgomery factors.
/// Subtracts `b` from `a` modulo `m` via the external `LIMBS_sub_mod`
/// routine, reusing the storage of `a` for the difference. Both operands
/// share the encoding `E`, which the result keeps.
pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    prefixed_extern! {
        // `r` and `a` may alias.
        fn LIMBS_sub_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    // `a` is passed as both the destination and the minuend. Only the limb
    // count of `m` is passed.
    // NOTE(review): this relies on `a` and `b` having the same number of
    // limbs as `m`; confirm that the `M` tag guarantees it.
    unsafe {
        LIMBS_sub_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs().as_ptr(),
            m.limbs().len(),
        );
    }
    a
}
314 
/// The value 1, Montgomery-encoded some number of times.
///
/// `M` tags the modulus and `E` tags the encoding, exactly as for the wrapped
/// `Elem<M, E>`.
pub struct One<M, E>(Elem<M, E>);
317 
318 impl<M> One<M, RR> {
319     // Returns RR = = R**2 (mod n) where R = 2**r is the smallest power of
320     // 2**LIMB_BITS such that R > m.
321     //
322     // Even though the assembly on some 32-bit platforms works with 64-bit
323     // values, using `LIMB_BITS` here, rather than `N0::LIMBS_USED * LIMB_BITS`,
324     // is correct because R**2 will still be a multiple of the latter as
325     // `N0::LIMBS_USED` is either one or two.
newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self326     fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
327         let m_bits = m_bits.as_usize_bits();
328         let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
329 
330         // base = 2**(lg m - 1).
331         let bit = m_bits - 1;
332         let mut base = m.zero();
333         base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);
334 
335         // Double `base` so that base == R == 2**r (mod m). For normal moduli
336         // that have the high bit of the highest limb set, this requires one
337         // doubling. Unusual moduli require more doublings but we are less
338         // concerned about the performance of those.
339         //
340         // Then double `base` again so that base == 2*R (mod n), i.e. `2` in
341         // Montgomery form (`elem_exp_vartime()` requires the base to be in
342         // Montgomery form). Then compute
343         // RR = R**2 == base**r == R**r == (2**r)**r (mod n).
344         //
345         // Take advantage of the fact that `elem_mul_by_2` is faster than
346         // `elem_squared` by replacing some of the early squarings with shifts.
347         // TODO: Benchmark shift vs. squaring performance to determine the
348         // optimal value of `LG_BASE`.
349         const LG_BASE: usize = 2; // Shifts vs. squaring trade-off.
350         debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
351         let shifts = r - bit + LG_BASE;
352         // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
353         // since we require the modulus to have at least `MODULUS_MIN_LIMBS`
354         // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
355         // `r / LG_BASE` is non-zero.
356         //
357         // The maximum value of `r` is determined by
358         // `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of
359         // `LIMB_BITS` so the maximum Hamming Weight is bounded by
360         // `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit
361         // moduli the Hamming weight is 1. For the other common case of 3072
362         // the Hamming weight is 2.
363         let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
364         for _ in 0..shifts {
365             elem_mul_by_2(&mut base, m)
366         }
367         let RR = elem_exp_vartime(base, exponent, m);
368 
369         Self(Elem {
370             limbs: RR.limbs,
371             encoding: PhantomData, // PhantomData<RR>
372         })
373     }
374 }
375 
376 impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
as_ref(&self) -> &Elem<M, E>377     fn as_ref(&self) -> &Elem<M, E> {
378         &self.0
379     }
380 }
381 
382 impl<M: PublicModulus, E> Clone for One<M, E> {
clone(&self) -> Self383     fn clone(&self) -> Self {
384         Self(self.0.clone())
385     }
386 }
387 
388 /// Calculates base**exponent (mod m).
389 ///
390 /// The run time  is a function of the number of limbs in `m` and the bit
391 /// length and Hamming Weight of `exponent`. The bounds on `m` are pretty
392 /// obvious but the bounds on `exponent` are less obvious. Callers should
393 /// document the bounds they place on the maximum value and maximum Hamming
394 /// weight of `exponent`.
395 // TODO: The test coverage needs to be expanded, e.g. test with the largest
396 // accepted exponent and with the most common values of 65537 and 3.
elem_exp_vartime<M>( base: Elem<M, R>, exponent: NonZeroU64, m: &PartialModulus<M>, ) -> Elem<M, R>397 pub(crate) fn elem_exp_vartime<M>(
398     base: Elem<M, R>,
399     exponent: NonZeroU64,
400     m: &PartialModulus<M>,
401 ) -> Elem<M, R> {
402     // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
403     // square-and-multiply that scans the exponent from the most significant
404     // bit to the least significant bit (left-to-right). Left-to-right requires
405     // less storage compared to right-to-left scanning, at the cost of needing
406     // to compute `exponent.leading_zeros()`, which we assume to be cheap.
407     //
408     // As explained in [Knuth], exponentiation by squaring is the most
409     // efficient algorithm when the Hamming weight is 2 or less. It isn't the
410     // most efficient for all other, uncommon, exponent values but any
411     // suboptimality is bounded at least by the small bit length of `exponent`
412     // as enforced by its type.
413     //
414     // This implementation is slightly simplified by taking advantage of the
415     // fact that we require the exponent to be a positive integer.
416     //
417     // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
418     //          Algorithms (3rd Edition), Section 4.6.3.
419     let exponent = exponent.get();
420     let mut acc = base.clone();
421     let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
422     debug_assert!((exponent & bit) != 0);
423     while bit > 1 {
424         bit >>= 1;
425         acc = elem_squared(acc, m);
426         if (exponent & bit) != 0 {
427             acc = elem_mul_(&base, acc, m);
428         }
429     }
430     acc
431 }
432 
433 /// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
elem_inverse_consttime<M: Prime>( a: Elem<M, R>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>434 pub fn elem_inverse_consttime<M: Prime>(
435     a: Elem<M, R>,
436     m: &Modulus<M>,
437 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
438     elem_exp_consttime(a, &PrivateExponent::for_flt(m), m)
439 }
440 
/// Calculates base**exponent (mod m) using a fixed 5-bit window and a table
/// of the first 32 powers of `base`. This is the implementation used on
/// targets other than x86_64.
///
/// `base` must be Montgomery-encoded; the result is unencoded. The visible
/// code always returns `Ok`; failures inside the helpers surface as panics.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::{bssl, limb::Window};

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs().len();

    // `TABLE_ENTRIES` consecutive entries of `num_limbs` limbs each; entry
    // `i` will hold base**i in Montgomery form.
    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    // Copies table entry `i` into `r` using the external
    // `LIMBS_select_512_32` routine; panics if that routine reports failure.
    // NOTE(review): presumably the selection is done without
    // secret-dependent memory access; confirm in the C implementation.
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        prefixed_extern! {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    // One window step: acc = acc**(2**WINDOW_BITS) * table[i]. `tmp` is
    // scratch storage for the gathered entry and is handed back for reuse.
    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    // tmp = 1 * RR / R = R (mod m), i.e. the value 1 in Montgomery form.
    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    // Accessors for the `i`th `num_limbs`-sized entry of a table slice.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }
    // table[0] = base**0 and table[1] = base**1.
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
    // table[i] = table[i / 2]**2 for even `i`, table[i - 1] * table[1] for
    // odd `i`. Both sources always have a smaller index than `i`.
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        // Split so the already-computed entries can be read while entry `i`
        // (the first entry of `rest`) is written.
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, m.limbs(), m.n0(), m.cpu_features());
    }

    // The first closure initializes the accumulator from the initial window
    // and the second applies `power()` for each subsequent window.
    // NOTE(review): the window order is defined by
    // `limb::fold_5_bit_windows()`; presumably most significant first.
    let (r, _) = limb::fold_5_bit_windows(
        exponent.limbs(),
        |initial_window| {
            // Reuse the storage of `base` for the accumulator; its contents
            // are overwritten by `gather()`.
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    // Cancel the Montgomery factor to get the unencoded result.
    let r = r.into_unencoded(m);

    Ok(r)
}
527 
/// Calculates base**exponent (mod m) using a fixed 5-bit window and a table
/// of the first 32 powers of `base`. This is the x86_64 implementation, built
/// on the external `bn_scatter5`, `bn_gather5`, `bn_mul_mont_gather5` and
/// `bn_power5` routines.
///
/// `base` must be Montgomery-encoded; the result is unencoded. The visible
/// code always returns `Ok`; violated internal assumptions surface as panics.
#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::LIMB_BYTES;

    // Pretty much all the math here requires CPU feature detection to have
    // been done. `cpu_features` isn't threaded through all the internal
    // functions, so just make it clear that it has been done at this point.
    let cpu_features = m.cpu_features();

    // The x86_64 assembly was written under the assumption that the input data
    // is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
    // Similarly, OpenSSL uses the x86_64 assembly functions by giving it only
    // inputs `tmp`, `am`, and `np` that immediately follow the table. All the
    // awkwardness here stems from trying to use the assembly code like OpenSSL
    // does.

    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs().len();

    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    // Room for the table, the three state entries (`ACC`, `BASE`, `M`) that
    // follow it, and slack for aligning the start. The slack is `ALIGNMENT`
    // limbs, which is more than the at most `ALIGNMENT / LIMB_BYTES` limbs
    // skipped below.
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        // Skip leading limbs so that the table starts on an
        // `ALIGNMENT`-byte boundary. When the allocation is already aligned
        // a full `ALIGNMENT` bytes are skipped, which keeps it aligned.
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        // `state` is everything after the table entries.
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    // Accessors for the `i`th `num_limbs`-sized entry of a slice.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // Indices of the entries within `state`.
    const ACC: usize = 0; // `tmp` in OpenSSL
    const BASE: usize = ACC + 1; // `am` in OpenSSL
    const M: usize = BASE + 1; // `np` in OpenSSL

    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(m.limbs());

    // Stores the `ACC` entry of `state` into the table at index `i`.
    // NOTE(review): the in-memory layout of the table is defined by
    // `bn_scatter5`/`bn_gather5`; it is not indexed with `entry()`.
    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    // Loads table index `i` into the `ACC` entry of `state`.
    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    // ACC = table[i]**2.
    fn gather_square(
        table: &[Limb],
        state: &mut [Limb],
        n0: &N0,
        i: Window,
        num_limbs: usize,
        cpu_features: cpu::Features,
    ) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        // Split `ACC` off so it can be mutated while the modulus is read.
        // `rest` starts at `BASE`, so the modulus is at index `M - 1` in it.
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0, cpu_features);
    }

    // Calls `bn_mul_mont_gather5` with `ACC` as output, `BASE` as the first
    // operand and table index `i` as the `power` argument.
    // NOTE(review): presumably this computes ACC = BASE * table[i]; confirm
    // against the assembly.
    fn gather_mul_base_amm(
        table: &[Limb],
        state: &mut [Limb],
        n0: &N0,
        i: Window,
        num_limbs: usize,
    ) {
        prefixed_extern! {
            fn bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // Calls `bn_power5` with `ACC` as both the output and the input.
    // NOTE(review): presumably this computes
    // ACC = ACC**(2**WINDOW_BITS) * table[i], the same step as `power()` in
    // the non-x86_64 implementation; confirm against the assembly.
    fn power_amm(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0.
    {
        // `ACC` is still all-zero from the allocation, so setting limb 0
        // makes it the value 1; multiplying by RR puts it in Montgomery form.
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR().0.limbs, m.limbs(), m.n0(), cpu_features);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

    // table[i] = table[i / 2]**2 for even `i`; for odd `i` it is built from
    // `BASE` and table[i - 1].
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            // TODO: Optimize this to avoid gathering
            gather_square(table, state, m.n0(), i / 2, num_limbs, cpu_features);
        } else {
            gather_mul_base_amm(table, state, m.n0(), i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    // The first closure loads the initial window's table entry into `ACC`
    // and the second applies `power_amm()` for each subsequent window.
    // NOTE(review): the window order is defined by
    // `limb::fold_5_bit_windows()`; presumably most significant first.
    let state = limb::fold_5_bit_windows(
        exponent.limbs(),
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power_amm(table, state, m.n0(), window, num_limbs);
            state
        },
    );

    // Reuse the storage of `base` for the result.
    let mut r_amm = base.limbs;
    r_amm.copy_from_slice(entry(state, ACC, num_limbs));

    // Cancel the Montgomery factor and fully reduce the result mod `m`.
    Ok(from_montgomery_amm(r_amm, m))
}
717 
718 /// Verified a == b**-1 (mod m), i.e. a**-1 == b (mod m).
verify_inverses_consttime<M>( a: &Elem<M, R>, b: Elem<M, Unencoded>, m: &Modulus<M>, ) -> Result<(), error::Unspecified>719 pub fn verify_inverses_consttime<M>(
720     a: &Elem<M, R>,
721     b: Elem<M, Unencoded>,
722     m: &Modulus<M>,
723 ) -> Result<(), error::Unspecified> {
724     if elem_mul(a, b, m).is_one() {
725         Ok(())
726     } else {
727         Err(error::Unspecified)
728     }
729 }
730 
731 #[inline]
elem_verify_equal_consttime<M, E>( a: &Elem<M, E>, b: &Elem<M, E>, ) -> Result<(), error::Unspecified>732 pub fn elem_verify_equal_consttime<M, E>(
733     a: &Elem<M, E>,
734     b: &Elem<M, E>,
735 ) -> Result<(), error::Unspecified> {
736     if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
737         Ok(())
738     } else {
739         Err(error::Unspecified)
740     }
741 }
742 
743 // TODO: Move these methods from `Nonnegative` to `Modulus`.
744 impl Nonnegative {
to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified>745     pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
746         self.verify_less_than_modulus(m)?;
747         let mut r = m.zero();
748         r.limbs[0..self.limbs().len()].copy_from_slice(self.limbs());
749         Ok(r)
750     }
751 
verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified>752     pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
753         if self.limbs().len() > m.limbs().len() {
754             return Err(error::Unspecified);
755         }
756         if self.limbs().len() == m.limbs().len() {
757             if limb::limbs_less_than_limbs_consttime(self.limbs(), m.limbs()) != LimbMask::True {
758                 return Err(error::Unspecified);
759             }
760         }
761         Ok(())
762     }
763 }
764 
/// r *= a
///
/// Montgomery-multiplies `r` by `a` modulo `m` in place by calling
/// `bn_mul_mont` with `r` as both the output and the first operand.
///
/// `_cpu_features` is not read. NOTE(review): presumably it is taken as
/// evidence that CPU feature detection has already run; confirm.
///
/// NOTE(review): the length requirements are only checked with
/// `debug_assert_eq!`, so in release builds the caller is relied on to pass
/// `r`, `a` and `m` of equal length.
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    // `bn_mul_mont` is declared with a note that its output and inputs may
    // alias, which is what passing `r` twice relies on.
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            a.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}
780 
/// r = a * b
///
/// Montgomery-multiplies `a` by `b` modulo `m` and writes the product to
/// `r` by calling `bn_mul_mont`. Only compiled on targets other than x86_64.
///
/// `_cpu_features` is not read. NOTE(review): presumably it is taken as
/// evidence that CPU feature detection has already run; confirm.
///
/// NOTE(review): the length requirements are only checked with
/// `debug_assert_eq!`, so in release builds the caller is relied on to pass
/// `r`, `a`, `b` and `m` of equal length.
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(
    r: &mut [Limb],
    a: &[Limb],
    b: &[Limb],
    m: &[Limb],
    n0: &N0,
    _cpu_features: cpu::Features,
) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}
806 
/// r = r**2
///
/// Squares `r` in place by passing it to `bn_mul_mont` as the output and as
/// both inputs. `m` and `n0` are forwarded to `bn_mul_mont` unchanged;
/// `_cpu_features` is not used by this function.
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
    // Debug-only check that `r` has exactly as many limbs as `m`.
    debug_assert_eq!(r.len(), m.len());
    // SAFETY (as far as visible here): every pointer is derived from a live
    // slice and `r.len()` is passed as the limb count. All three of the
    // output and input pointers refer to `r`; the comment on the
    // `bn_mul_mont` declaration states that such aliasing is permitted.
    // NOTE(review): the length equality is only enforced in debug builds, so
    // callers must guarantee it in release builds — confirm.
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            r.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}
821 
prefixed_extern! {
    // Foreign routine called by `limbs_mont_mul`, `limbs_mont_product` and
    // `limbs_mont_square` in this file. Going by those wrappers' doc comments
    // it stores the product of `a` and `b` in `r`; presumably this is the
    // Montgomery product modulo `n` using the constant `n0` — TODO confirm
    // against the implementation.
    //
    // The wrappers pass slices of `num_limbs` limbs for `r`, `a`, `b` and `n`
    // (checked with `debug_assert_eq!` only).
    //
    // `r` and/or 'a' and/or 'b' may alias.
    fn bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );
}
833 
834 #[cfg(test)]
835 mod tests {
836     use super::{modulus::MODULUS_MIN_LIMBS, *};
837     use crate::{limb::LIMB_BYTES, test};
838     use alloc::format;
839 
    // Type-level representation of an arbitrary modulus.
    //
    // Used as the `M` type parameter of `Modulus<M>` / `Elem<M, _>` in the
    // tests below; it has no fields.
    struct M {}

    // Marks `M` as a `PublicModulus`.
    // NOTE(review): presumably this is what allows `Modulus<M>` to be
    // formatted with `{:?}` in `test_modulus_debug` — confirm.
    impl PublicModulus for M {}
844 
845     #[test]
test_elem_exp_consttime()846     fn test_elem_exp_consttime() {
847         let cpu_features = cpu::features();
848         test::run(
849             test_file!("../../crypto/fipsmodule/bn/test/mod_exp_tests.txt"),
850             |section, test_case| {
851                 assert_eq!(section, "");
852 
853                 let m = consume_modulus::<M>(test_case, "M", cpu_features);
854                 let expected_result = consume_elem(test_case, "ModExp", &m);
855                 let base = consume_elem(test_case, "A", &m);
856                 let e = {
857                     let bytes = test_case.consume_bytes("E");
858                     PrivateExponent::from_be_bytes_for_test_only(untrusted::Input::from(&bytes), &m)
859                         .expect("valid exponent")
860                 };
861                 let base = into_encoded(base, &m);
862                 let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
863                 assert_elem_eq(&actual_result, &expected_result);
864 
865                 Ok(())
866             },
867         )
868     }
869 
870     // TODO: fn test_elem_exp_vartime() using
871     // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
872     // In the meantime, the function is tested indirectly via the RSA
873     // verification and signing tests.
874     #[test]
test_elem_mul()875     fn test_elem_mul() {
876         let cpu_features = cpu::features();
877         test::run(
878             test_file!("../../crypto/fipsmodule/bn/test/mod_mul_tests.txt"),
879             |section, test_case| {
880                 assert_eq!(section, "");
881 
882                 let m = consume_modulus::<M>(test_case, "M", cpu_features);
883                 let expected_result = consume_elem(test_case, "ModMul", &m);
884                 let a = consume_elem(test_case, "A", &m);
885                 let b = consume_elem(test_case, "B", &m);
886 
887                 let b = into_encoded(b, &m);
888                 let a = into_encoded(a, &m);
889                 let actual_result = elem_mul(&a, b, &m);
890                 let actual_result = actual_result.into_unencoded(&m);
891                 assert_elem_eq(&actual_result, &expected_result);
892 
893                 Ok(())
894             },
895         )
896     }
897 
898     #[test]
test_elem_squared()899     fn test_elem_squared() {
900         let cpu_features = cpu::features();
901         test::run(
902             test_file!("bigint_elem_squared_tests.txt"),
903             |section, test_case| {
904                 assert_eq!(section, "");
905 
906                 let m = consume_modulus::<M>(test_case, "M", cpu_features);
907                 let expected_result = consume_elem(test_case, "ModSquare", &m);
908                 let a = consume_elem(test_case, "A", &m);
909 
910                 let a = into_encoded(a, &m);
911                 let actual_result = elem_squared(a, &m.as_partial());
912                 let actual_result = actual_result.into_unencoded(&m);
913                 assert_elem_eq(&actual_result, &expected_result);
914 
915                 Ok(())
916             },
917         )
918     }
919 
920     #[test]
test_elem_reduced()921     fn test_elem_reduced() {
922         let cpu_features = cpu::features();
923         test::run(
924             test_file!("bigint_elem_reduced_tests.txt"),
925             |section, test_case| {
926                 assert_eq!(section, "");
927 
928                 struct MM {}
929                 unsafe impl SmallerModulus<MM> for M {}
930                 unsafe impl NotMuchSmallerModulus<MM> for M {}
931 
932                 let m = consume_modulus::<M>(test_case, "M", cpu_features);
933                 let expected_result = consume_elem(test_case, "R", &m);
934                 let a =
935                     consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);
936 
937                 let actual_result = elem_reduced(&a, &m);
938                 let oneRR = m.oneRR();
939                 let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
940                 assert_elem_eq(&actual_result, &expected_result);
941 
942                 Ok(())
943             },
944         )
945     }
946 
947     #[test]
test_elem_reduced_once()948     fn test_elem_reduced_once() {
949         let cpu_features = cpu::features();
950         test::run(
951             test_file!("bigint_elem_reduced_once_tests.txt"),
952             |section, test_case| {
953                 assert_eq!(section, "");
954 
955                 struct N {}
956                 struct QQ {}
957                 unsafe impl SmallerModulus<N> for QQ {}
958                 unsafe impl SlightlySmallerModulus<N> for QQ {}
959 
960                 let qq = consume_modulus::<QQ>(test_case, "QQ", cpu_features);
961                 let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
962                 let n = consume_modulus::<N>(test_case, "N", cpu_features);
963                 let a = consume_elem::<N>(test_case, "A", &n);
964 
965                 let actual_result = elem_reduced_once(&a, &qq);
966                 assert_elem_eq(&actual_result, &expected_result);
967 
968                 Ok(())
969             },
970         )
971     }
972 
973     #[test]
test_modulus_debug()974     fn test_modulus_debug() {
975         let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(
976             untrusted::Input::from(&[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS]),
977             cpu::features(),
978         )
979         .unwrap();
980         assert_eq!("Modulus", format!("{:?}", modulus));
981     }
982 
consume_elem<M>( test_case: &mut test::TestCase, name: &str, m: &Modulus<M>, ) -> Elem<M, Unencoded>983     fn consume_elem<M>(
984         test_case: &mut test::TestCase,
985         name: &str,
986         m: &Modulus<M>,
987     ) -> Elem<M, Unencoded> {
988         let value = test_case.consume_bytes(name);
989         Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
990     }
991 
consume_elem_unchecked<M>( test_case: &mut test::TestCase, name: &str, num_limbs: usize, ) -> Elem<M, Unencoded>992     fn consume_elem_unchecked<M>(
993         test_case: &mut test::TestCase,
994         name: &str,
995         num_limbs: usize,
996     ) -> Elem<M, Unencoded> {
997         let value = consume_nonnegative(test_case, name);
998         let mut limbs = BoxedLimbs::zero(Width {
999             num_limbs,
1000             m: PhantomData,
1001         });
1002         limbs[0..value.limbs().len()].copy_from_slice(value.limbs());
1003         Elem {
1004             limbs,
1005             encoding: PhantomData,
1006         }
1007     }
1008 
consume_modulus<M>( test_case: &mut test::TestCase, name: &str, cpu_features: cpu::Features, ) -> Modulus<M>1009     fn consume_modulus<M>(
1010         test_case: &mut test::TestCase,
1011         name: &str,
1012         cpu_features: cpu::Features,
1013     ) -> Modulus<M> {
1014         let value = test_case.consume_bytes(name);
1015         let (value, _) =
1016             Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value), cpu_features)
1017                 .unwrap();
1018         value
1019     }
1020 
consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative1021     fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
1022         let bytes = test_case.consume_bytes(name);
1023         let (r, _r_bits) =
1024             Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
1025         r
1026     }
1027 
assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>)1028     fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
1029         if elem_verify_equal_consttime(a, b).is_err() {
1030             panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
1031         }
1032     }
1033 
into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R>1034     fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
1035         elem_mul(m.oneRR().as_ref(), a, m)
1036     }
1037 }
1038