xref: /aosp_15_r20/external/boringssl/src/crypto/kyber/kyber.c (revision 8fb009dc861624b67b6cdb62ea21f0f22d0c584b)
1 /* Copyright (c) 2023, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #define OPENSSL_UNSTABLE_EXPERIMENTAL_KYBER
16 #include <openssl/experimental/kyber.h>
17 
18 #include <assert.h>
19 #include <stdlib.h>
20 
21 #include <openssl/bytestring.h>
22 #include <openssl/rand.h>
23 
24 #include "../internal.h"
25 #include "../keccak/internal.h"
26 #include "./internal.h"
27 
28 
29 // See
30 // https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
31 
// prf is the PRF function from the Kyber spec: SHAKE-256. |in| is a 32-byte
// seed followed by a one-byte nonce.
static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
}
35 
// hash_h is the H function from the Kyber spec: SHA3-256.
static void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
}
39 
// hash_g is the G function from the Kyber spec: SHA3-512.
static void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
}
43 
// kdf is the KDF function from the Kyber spec: SHAKE-256.
static void kdf(uint8_t *out, size_t out_len, const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, out_len, in, len, boringssl_shake256);
}
47 
#define DEGREE 256
#define RANK 3

// kBarrettMultiplier is floor(2^kBarrettShift / kPrime); used by |reduce| and
// |compress| for constant-time Barrett reduction.
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
// kPrime is the Kyber modulus, q.
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
// kDU and kDV are the ciphertext compression bit-widths for u and v.
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;
64 
// scalar is an element of Z_q[X]/(X^DEGREE + 1): a polynomial of degree less
// than DEGREE with coefficients modulo kPrime.
typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

// vector is a length-RANK vector of ring elements.
typedef struct vector {
  scalar v[RANK];
} vector;

// matrix is a RANK x RANK matrix of ring elements.
typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;
77 
78 // This bit of Python will be referenced in some of the following comments:
79 //
80 // p = 3329
81 //
82 // def bitreverse(i):
83 //     ret = 0
84 //     for n in range(7):
85 //         bit = i & 1
86 //         ret <<= 1
87 //         ret |= bit
88 //         i >>= 1
89 //     return ret
90 
// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
// i.e. the powers of 17 (a 256th root of unity mod p) in bit-reversed order,
// as consumed by |scalar_ntt|.
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
105 
// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
// i.e. the multiplicative inverses of |kNTTRoots|, consumed by
// |scalar_inverse_ntt|.
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};
120 
// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
// i.e. the value of X^2 in the quotient ring used for the i-th coefficient
// pair in |scalar_mult|.
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
135 
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  // If the subtraction underflowed (x < kPrime), |subtracted|'s top bit is set
  // and |mask| becomes all-ones; otherwise it is zero.
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of Kyber
  // overall and Clang still produces constant-time code using `csel`. On other
  // platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}
147 
// constant time reduce x mod kPrime using Barrett reduction. x must be less
// than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  // |quotient| approximates x / kPrime from below, so |remainder| ends up in
  // [0, 2*kPrime), which |reduce_once| finishes off.
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}
157 
// scalar_zero sets all coefficients of |out| to zero.
static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
159 
// vector_zero sets all entries of |out| to zero.
static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
161 
// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      // Butterfly over the pair (s->c[j], s->c[j + offset]).
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        // Adding kPrime keeps the argument of |reduce_once| non-negative.
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}
187 
// vector_ntt applies the NTT, in place, to every entry of |a|.
static void vector_ntt(vector *a) {
  scalar *entry = a->v;
  for (int remaining = RANK; remaining > 0; remaining--, entry++) {
    scalar_ntt(entry);
  }
}
193 
// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->v being interpreted as elements of GF(3329^2). Just as with the
// number theoretic transform, this leaves off the first step of the normal iFFT
// to account for the fact that 3329 does not have a 512th root of unity, using
// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        // Adding kPrime keeps the multiplicand non-negative before reduction.
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  // Scale by DEGREE/2^-1 = 128^-1 (|kInverseDegree|) to complete the inverse.
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}
221 
// vector_inverse_ntt applies the inverse NTT, in place, to every entry of |a|.
static void vector_inverse_ntt(vector *a) {
  scalar *entry = a->v;
  for (int remaining = RANK; remaining > 0; remaining--, entry++) {
    scalar_inverse_ntt(entry);
  }
}
227 
// scalar_add sets |lhs| to |lhs| + |rhs|, coefficient-wise mod kPrime.
static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx != DEGREE; ++idx) {
    const uint16_t sum = lhs->c[idx] + rhs->c[idx];
    lhs->c[idx] = reduce_once(sum);
  }
}
233 
// scalar_sub sets |lhs| to |lhs| - |rhs|, coefficient-wise mod kPrime. Adding
// kPrime keeps the intermediate value non-negative for |reduce_once|.
static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx != DEGREE; ++idx) {
    const uint16_t diff = lhs->c[idx] - rhs->c[idx] + kPrime;
    lhs->c[idx] = reduce_once(diff);
  }
}
239 
// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
// - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
// stored in the precomputed |kModRoots| table. Note that our Barrett transform
// only allows us to multipy two reduced numbers together, so we need some
// intermediate reduction steps, even if an uint64_t could hold 3 multiplied
// numbers.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    // (a + bX)(c + dX) = (ac + bd*root) + (ad + bc)X, with X^2 = kModRoots[i].
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    // |img_img| is reduced first so that the product with kModRoots[i] stays
    // within |reduce|'s input bound.
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}
259 
// vector_add sets |lhs| to |lhs| + |rhs|, entry-wise.
static void vector_add(vector *lhs, const vector *rhs) {
  for (int idx = RANK; idx-- > 0;) {
    scalar_add(&lhs->v[idx], &rhs->v[idx]);
  }
}
265 
// matrix_mult sets |out| to |m| × |a|, with all elements in the NTT domain.
static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < RANK; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(acc, &term);
    }
  }
}
276 
// matrix_mult_transpose sets |out| to |m|^T × |a|, with all elements in the
// NTT domain.
static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < RANK; col++) {
      scalar term;
      scalar_mult(&term, &m->v[col][row], &a->v[col]);
      scalar_add(acc, &term);
    }
  }
}
288 
// scalar_inner_product sets |out| to the dot product of |lhs| and |rhs|, with
// all elements in the NTT domain.
static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int idx = 0; idx != RANK; ++idx) {
    scalar term;
    scalar_mult(&term, &lhs->v[idx], &rhs->v[idx]);
    scalar_add(out, &term);
  }
}
298 
// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    // Each group of 3 bytes yields two 12-bit candidates, |d1| and |d2|.
    // Candidates >= kPrime are rejected, hence "vartime" (on public data).
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}
324 
// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binominally distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first bits minus the count of
// the second bits, resulting in a centered binomial distribution. Since eta is
// two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
// and 0 with probability 3/8.
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  prf(entropy, sizeof(entropy), input);

  // Each byte provides 2*eta = 4 bits for each of two coefficients.
  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    // Start from kPrime so the subtraction cannot underflow; the result lies
    // in [kPrime - 2, kPrime + 2], which |reduce_once| maps into [0, kPrime).
    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}
352 
353 // Generates a secret vector by using
354 // |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
355 // appending and incrementing |counter| for entry of the vector.
vector_generate_secret_eta_2(vector * out,uint8_t * counter,const uint8_t seed[32])356 static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
357                                          const uint8_t seed[32]) {
358   uint8_t input[33];
359   OPENSSL_memcpy(input, seed, 32);
360   for (int i = 0; i < RANK; i++) {
361     input[32] = (*counter)++;
362     scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
363   }
364 }
365 
366 // Expands the matrix of a seed for key generation and for encaps-CPA.
matrix_expand(matrix * out,const uint8_t rho[32])367 static void matrix_expand(matrix *out, const uint8_t rho[32]) {
368   uint8_t input[34];
369   OPENSSL_memcpy(input, rho, 32);
370   for (int i = 0; i < RANK; i++) {
371     for (int j = 0; j < RANK; j++) {
372       input[32] = i;
373       input[33] = j;
374       struct BORINGSSL_keccak_st keccak_ctx;
375       BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
376       BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
377       scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
378     }
379   }
380 }
381 
// kMasks[i] has its low i+1 bits set, for extracting i+1-bit chunks.
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};
384 
// scalar_encode packs the DEGREE coefficients of |s| into |out| using |bits|
// bits each, filling each output byte from the least-significant bit up. |out|
// must have space for DEGREE * |bits| / 8 bytes (rounded up; exact for the
// |bits| values used in this file since DEGREE is a multiple of 8).
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        // The current output byte is full: write it out and start a new one.
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  // Flush a trailing partial byte, if any.
  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}
419 
420 // scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
scalar_encode_1(uint8_t out[32],const scalar * s)421 static void scalar_encode_1(uint8_t out[32], const scalar *s) {
422   for (int i = 0; i < DEGREE; i += 8) {
423     uint8_t out_byte = 0;
424     for (int j = 0; j < 8; j++) {
425       out_byte |= (s->c[i + j] & 1) << j;
426     }
427     *out = out_byte;
428     out++;
429   }
430 }
431 
432 // Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
433 // (DEGREE) is divisible by 8, the individual vector entries will always fill a
434 // whole number of bytes, so we do not need to worry about bit packing here.
vector_encode(uint8_t * out,const vector * a,int bits)435 static void vector_encode(uint8_t *out, const vector *a, int bits) {
436   for (int i = 0; i < RANK; i++) {
437     scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
438   }
439 }
440 
// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    // Assemble |bits| bits for this coefficient, LSB first, refilling
    // |in_byte| from the input as each byte is exhausted.
    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    // Reject out-of-range coefficients so the |scalar| invariant
    // (0 <= c < kPrime) holds on success.
    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}
481 
482 // scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
scalar_decode_1(scalar * out,const uint8_t in[32])483 static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
484   for (int i = 0; i < DEGREE; i += 8) {
485     uint8_t in_byte = *in;
486     in++;
487     for (int j = 0; j < 8; j++) {
488       out->c[i + j] = in_byte & 1;
489       in_byte >>= 1;
490     }
491   }
492 }
493 
494 // Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
495 // success or zero if any parsed value is >= |kPrime|.
vector_decode(vector * out,const uint8_t * in,int bits)496 static int vector_decode(vector *out, const uint8_t *in, int bits) {
497   for (int i = 0; i < RANK; i++) {
498     if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
499       return 0;
500     }
501   }
502   return 1;
503 }
504 
// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  // Reduce mod 2^|bits| by masking.
  return quotient & ((1 << bits) - 1);
}
526 
// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of numbers mod |power| have a
  // 0 as first bit, and the second half has a 1 as first bit, since |power| is
  // a power of 2. As a 12 bit number, |remainder| is always positive, so we
  // will shift in 0s for a right shift.
  return lower + (remainder >> (bits - 1));
}
543 
// scalar_compress applies |compress| to every coefficient of |s| in place.
static void scalar_compress(scalar *s, int bits) {
  uint16_t *coeff = s->c;
  for (int remaining = DEGREE; remaining > 0; remaining--, coeff++) {
    *coeff = compress(*coeff, bits);
  }
}
549 
// scalar_decompress applies |decompress| to every coefficient of |s| in place.
static void scalar_decompress(scalar *s, int bits) {
  uint16_t *coeff = s->c;
  for (int remaining = DEGREE; remaining > 0; remaining--, coeff++) {
    *coeff = decompress(*coeff, bits);
  }
}
555 
// vector_compress applies |scalar_compress| to every entry of |a| in place.
static void vector_compress(vector *a, int bits) {
  for (scalar *entry = a->v; entry != a->v + RANK; entry++) {
    scalar_compress(entry, bits);
  }
}
561 
// vector_decompress applies |scalar_decompress| to every entry of |a| in
// place.
static void vector_decompress(vector *a, int bits) {
  for (scalar *entry = a->v; entry != a->v + RANK; entry++) {
    scalar_decompress(entry, bits);
  }
}
567 
// public_key is the internal representation of a Kyber public key.
struct public_key {
  vector t;                     // t = A^T·s + e, stored in the NTT domain.
  uint8_t rho[32];              // seed used to expand the matrix |m|.
  uint8_t public_key_hash[32];  // H(serialised public key), cached for encap.
  matrix m;                     // matrix expanded from |rho|; not serialised.
};
574 
// public_key_from_external reinterprets the opaque |KYBER_public_key| as the
// internal |public_key| structure. The static_asserts verify that the opaque
// type is large enough and sufficiently aligned to alias the real one. Note
// the cast also drops const; const inputs are only ever read through it.
static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}
583 
// private_key is the internal representation of a Kyber private key.
struct private_key {
  struct public_key pub;          // the corresponding public key.
  vector s;                       // secret vector, stored in the NTT domain.
  uint8_t fo_failure_secret[32];  // implicit-rejection secret for the FO
                                  // transform (copied from key-gen entropy).
};
589 
// private_key_from_external reinterprets the opaque |KYBER_private_key| as the
// internal |private_key| structure, with static_asserts checking size and
// alignment. Note the cast also drops const; const inputs are only ever read
// through it.
static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}
599 
// Calls |KYBER_generate_key_external_entropy| with random bytes from
// |RAND_bytes|.
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
                        struct KYBER_private_key *out_private_key) {
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
  RAND_bytes(entropy, sizeof(entropy));
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
                                      entropy);
}
609 
// kyber_marshal_public_key serialises |pub| into |out| as the encoding of |t|
// followed by |rho|, returning one on success and zero if writing to |out|
// fails.
static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *vector_output;
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
    return 0;
  }
  // Each coefficient of |t| is written with kLog2Prime (12) bits.
  vector_encode(vector_output, &pub->t, kLog2Prime);
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  return 1;
}
621 
// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  // G splits the first 32 bytes of entropy into rho (public matrix seed) and
  // sigma (secret/noise sampling seed).
  uint8_t hashed[64];
  hash_g(hashed, entropy, 32);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  // t = A^T·s + e, computed entirely in the NTT domain.
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    // Cannot happen: the CBB is sized to exactly fit the encoding.
    abort();
  }

  // Cache H(public key) for use during encapsulation/decapsulation.
  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
         KYBER_PUBLIC_KEY_BYTES);
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}
655 
KYBER_public_from_private(struct KYBER_public_key * out_public_key,const struct KYBER_private_key * private_key)656 void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
657                                const struct KYBER_private_key *private_key) {
658   struct public_key *const pub = public_key_from_external(out_public_key);
659   const struct private_key *const priv = private_key_from_external(private_key);
660   *pub = priv->pub;
661 }
662 
// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  uint8_t counter = 0;
  // r: the encryption secret vector (spec notation).
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  // e1: the error vector added to u.
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  // e2: the error scalar added to v, sampled with the next PRF nonce.
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  // Add Decompress_q(message, 1): each message bit becomes 0 or round(q/2).
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  // Ciphertext is Compress(u, dU) || Compress(v, dV).
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}
698 
// Calls |KYBER_encap_external_entropy| with random bytes from |RAND_bytes|.
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
                 uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const struct KYBER_public_key *public_key) {
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
                               entropy);
}
708 
// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
    const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  // (K', r) = G(entropy || H(public key))
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, input, sizeof(input));
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  // Shared secret = KDF(K' || H(ciphertext)).
  hash_h(prekey_and_randomness + 32, out_ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, prekey_and_randomness,
      sizeof(prekey_and_randomness));
}
732 
733 // Algorithm 6 of the Kyber spec.
decrypt_cpa(uint8_t out[32],const struct private_key * priv,const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES])734 static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
735                         const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
736   vector u;
737   vector_decode(&u, ciphertext, kDU);
738   vector_decompress(&u, kDU);
739   vector_ntt(&u);
740   scalar v;
741   scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
742   scalar_decompress(&v, kDV);
743   scalar mask;
744   scalar_inner_product(&mask, &priv->s, &u);
745   scalar_inverse_ntt(&mask);
746   scalar_sub(&v, &mask);
747   scalar_compress(&v, 1);
748   scalar_encode_1(out, &v);
749 }
750 
751 // Algorithm 9 of the Kyber spec, performing the FO transform by running
752 // encrypt_cpa on the decrypted message. The spec does not allow the decryption
753 // failure to be passed on to the caller, and instead returns a result that is
754 // deterministic but unpredictable to anyone without knowledge of the private
755 // key.
KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],const struct KYBER_private_key * private_key)756 void KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
757                  const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
758                  const struct KYBER_private_key *private_key) {
759   const struct private_key *priv = private_key_from_external(private_key);
760   uint8_t decrypted[64];
761   decrypt_cpa(decrypted, priv, ciphertext);
762   OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
763                  sizeof(decrypted) - 32);
764   uint8_t prekey_and_randomness[64];
765   hash_g(prekey_and_randomness, decrypted, sizeof(decrypted));
766   uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
767   encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
768               prekey_and_randomness + 32);
769   uint8_t mask =
770       constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
771                                            sizeof(expected_ciphertext)),
772                              0);
773   uint8_t input[64];
774   for (int i = 0; i < 32; i++) {
775     input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
776                                       priv->fo_failure_secret[i]);
777   }
778   hash_h(input + 32, ciphertext, KYBER_CIPHERTEXT_BYTES);
779   kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, input, sizeof(input));
780 }
781 
// Serializes |public_key| into |out| in the standard encoding.
int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  const struct public_key *const pub = public_key_from_external(public_key);
  return kyber_marshal_public_key(out, pub);
}
786 
787 // kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
788 // the value of |pub->public_key_hash|.
kyber_parse_public_key_no_hash(struct public_key * pub,CBS * in)789 static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
790   CBS t_bytes;
791   if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
792       !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
793       !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
794     return 0;
795   }
796   matrix_expand(&pub->m, pub->rho);
797   return 1;
798 }
799 
// Parses a public key from |in|, which must contain exactly one encoded key,
// and caches H(pk) alongside it.
int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  // Remember the full encoding: the public-key hash covers all of it.
  const CBS encoded = *in;
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
      CBS_len(in) != 0) {
    return 0;
  }
  hash_h(pub->public_key_hash, CBS_data(&encoded), CBS_len(&encoded));
  return 1;
}
810 
// Serializes |private_key| into |out| as s || pk || H(pk) || z.
int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv = private_key_from_external(private_key);
  uint8_t *s_bytes;
  if (!CBB_add_space(out, &s_bytes, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_bytes, &priv->s, kLog2Prime);
  return kyber_marshal_public_key(out, &priv->pub) &&
         CBB_add_bytes(out, priv->pub.public_key_hash,
                       sizeof(priv->pub.public_key_hash)) &&
         CBB_add_bytes(out, priv->fo_failure_secret,
                       sizeof(priv->fo_failure_secret));
}
828 
// Parses a private key from |in|, which must hold exactly one encoded key in
// the format produced by |KYBER_marshal_private_key|.
int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
                            CBS *in) {
  struct private_key *const priv = private_key_from_external(out_private_key);

  // Secret vector s.
  CBS encoded_s;
  if (!CBS_get_bytes(in, &encoded_s, kEncodedVectorSize) ||
      !vector_decode(&priv->s, CBS_data(&encoded_s), kLog2Prime)) {
    return 0;
  }
  // Embedded public key.
  if (!kyber_parse_public_key_no_hash(&priv->pub, in)) {
    return 0;
  }
  // Cached H(pk) and the implicit-rejection secret z.
  if (!CBS_copy_bytes(in, priv->pub.public_key_hash,
                      sizeof(priv->pub.public_key_hash)) ||
      !CBS_copy_bytes(in, priv->fo_failure_secret,
                      sizeof(priv->fo_failure_secret))) {
    return 0;
  }
  return CBS_len(in) == 0;
}
846