/* Copyright (c) 2023, Google Inc.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#define OPENSSL_UNSTABLE_EXPERIMENTAL_KYBER
#include <openssl/experimental/kyber.h>

#include <assert.h>
#include <stdlib.h>

#include <openssl/bytestring.h>
#include <openssl/rand.h>

#include "../internal.h"
#include "../keccak/internal.h"
#include "./internal.h"


// See
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf

static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
}

static void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
}

static void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
}

static void kdf(uint8_t *out, size_t out_len, const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, out_len, in, len, boringssl_shake256);
}

#define DEGREE 256
#define RANK 3

static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;

typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

typedef struct vector {
  scalar v[RANK];
} vector;

typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;

// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};

// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};

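// The tables above can be regenerated from the Python snippet. The following
// self-check is an editorial sketch, not part of the original library, and is
// never called; it recomputes all three tables and asserts that they match.
static uint16_t mod_pow_3329(uint32_t base, uint32_t exp) {
  // Square-and-multiply; all intermediates fit in a uint32_t.
  uint32_t result = 1;
  base %= 3329;
  while (exp > 0) {
    if (exp & 1) {
      result = (result * base) % 3329;
    }
    base = (base * base) % 3329;
    exp >>= 1;
  }
  return (uint16_t)result;
}

static void check_root_tables(void) {
  for (uint32_t i = 0; i < 128; i++) {
    uint32_t r = 0, x = i;
    for (int n = 0; n < 7; n++) {  // bitreverse() from the Python snippet.
      r = (r << 1) | (x & 1);
      x >>= 1;
    }
    assert(kNTTRoots[i] == mod_pow_3329(17, r));
    // pow(17, -r, p) == pow(17, p - 1 - r, p) by Fermat's little theorem.
    assert(kInverseNTTRoots[i] == mod_pow_3329(17, 3329 - 1 - r));
    assert(kModRoots[i] == mod_pow_3329(17, 2 * r + 1));
  }
}
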
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of
  // Kyber overall and Clang still produces constant-time code using `csel`. On
  // other platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}

// Constant-time reduction of x mod kPrime using Barrett reduction. x must be
// less than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}

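// Editorial note (not in the original): kBarrettMultiplier is
// floor(2^kBarrettShift / kPrime) = floor(2^24 / 3329) = 5039, so the
// |quotient| computed in |reduce| is x/kPrime or one less, and the final
// |reduce_once| absorbs the off-by-one. For example, reduce(3330) == 1 and
// reduce(2 * 3329 * 3329) == 0.
static_assert((1u << 24) / 3329 == 5039,
              "kBarrettMultiplier must be floor(2^kBarrettShift / kPrime)");
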
static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}

static void vector_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_ntt(&a->v[i]);
  }
}

// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->v being interpreted as elements of GF(3329^2). Just as with
// the number theoretic transform, this leaves off the first step of the normal
// iFFT to account for the fact that 3329 does not have a 512th root of unity,
// using the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}

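// Editorial sketch (not part of the library, never called): the forward and
// inverse transforms are inverses of each other on reduced coefficients.
static void check_ntt_round_trip(void) {
  scalar s, copy;
  for (int i = 0; i < DEGREE; i++) {
    s.c[i] = (uint16_t)(i % kPrime);  // any values in [0, kPrime) will do
  }
  copy = s;
  scalar_ntt(&s);
  scalar_inverse_ntt(&s);
  for (int i = 0; i < DEGREE; i++) {
    assert(s.c[i] == copy.c[i]);
  }
}
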
static void vector_inverse_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_inverse_ntt(&a->v[i]);
  }
}

static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
  }
}

static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
  }
}

// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of
// GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). The value of
// 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed |kModRoots|
// table. Note that our Barrett reduction only allows us to multiply two
// reduced numbers together, so we need some intermediate reduction steps, even
// though a uint64_t could hold 3 multiplied numbers.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}

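// Editorial note (not in the original): for pair i, writing a = a0 + a1*X and
// b = b0 + b1*X with r = kModRoots[i], the product modulo (X^2 - r) is
//   a*b = (a0*b0 + a1*b1*r) + (a0*b1 + a1*b0)*X,
// which is exactly what |scalar_mult| computes. A small sketch (never called)
// checking the i = 0 base case against schoolbook arithmetic:
static void check_mult_base_case(void) {
  scalar a, b, out;
  scalar_zero(&a);
  scalar_zero(&b);
  a.c[0] = 1234;
  a.c[1] = 2345;
  b.c[0] = 987;
  b.c[1] = 3210;
  scalar_mult(&out, &a, &b);
  uint32_t r = kModRoots[0];  // 17
  assert(out.c[0] == (1234u * 987u + ((2345u * 3210u) % 3329u) * r) % 3329u);
  assert(out.c[1] == (1234u * 3210u + 2345u * 987u) % 3329u);
}
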
static void vector_add(vector *lhs, const vector *rhs) {
  for (int i = 0; i < RANK; i++) {
    scalar_add(&lhs->v[i], &rhs->v[i]);
  }
}

static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[i][j], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[j][i], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int i = 0; i < RANK; i++) {
    scalar product;
    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
    scalar_add(out, &product);
  }
}

// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}

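// Editorial worked example (not in the original): each 3-byte group encodes
// two 12-bit candidates. The bytes {0x01, 0x12, 0x03} give
//   d1 = 0x01 + 256 * (0x12 % 16) = 1 + 512 = 513
//   d2 = 0x12 / 16 + 16 * 0x03 = 1 + 48 = 49,
// both < kPrime, so both are accepted. The bytes {0xff, 0xff, 0xff} give
// d1 = d2 = 4095 >= kPrime, so both are rejected and more squeezed output is
// consumed. This rejection makes the loop variable-time, which is acceptable
// because only the public matrix seed is processed here.
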
// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binomially distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of ones in the first |eta| bits
// minus the count of ones in the second |eta| bits, resulting in a centered
// binomial distribution. Since eta is two, this gives -2/2 each with a
// probability of 1/16, -1/1 each with probability 1/4, and 0 with probability
// 3/8.
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  prf(entropy, sizeof(entropy), input);

  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}

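// Editorial sketch (not part of the library, never called): exhaustively
// tallying the 16 possible 4-bit inputs confirms the distribution quoted
// above: 0 occurs 6/16 = 3/8 of the time, +/-1 each occur 4/16 = 1/4, and
// +/-2 each occur 1/16.
static void check_eta_2_distribution(void) {
  int counts[5] = {0};  // counts[v + 2] holds the tally for v in {-2, ..., 2}
  for (int bits = 0; bits < 16; bits++) {
    int v = ((bits & 1) + ((bits >> 1) & 1)) -
            (((bits >> 2) & 1) + ((bits >> 3) & 1));
    counts[v + 2]++;
  }
  assert(counts[0] == 1 && counts[1] == 4 && counts[2] == 6 &&
         counts[3] == 4 && counts[4] == 1);
}
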
// Generates a secret vector by using
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given
// seed, appending and incrementing |counter| for each entry of the vector.
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
                                         const uint8_t seed[32]) {
  uint8_t input[33];
  OPENSSL_memcpy(input, seed, 32);
  for (int i = 0; i < RANK; i++) {
    input[32] = (*counter)++;
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
  }
}

// Expands the matrix of a seed for key generation and for encaps-CPA.
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
  uint8_t input[34];
  OPENSSL_memcpy(input, rho, 32);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      input[32] = i;
      input[33] = j;
      struct BORINGSSL_keccak_st keccak_ctx;
      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
    }
  }
}

static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};

static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}

// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t out_byte = 0;
    for (int j = 0; j < 8; j++) {
      out_byte |= (s->c[i + j] & 1) << j;
    }
    *out = out_byte;
    out++;
  }
}

// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
// whole number of bytes, so we do not need to worry about bit packing here.
static void vector_encode(uint8_t *out, const vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
  }
}

// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}

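// Editorial sketch (not part of the library, never called): |scalar_decode|
// inverts |scalar_encode| for any coefficients below both 2^|bits| and
// kPrime, e.g. at the full kLog2Prime = 12 bits used for keys.
static void check_encode_decode_round_trip(void) {
  scalar s, out;
  for (int i = 0; i < DEGREE; i++) {
    s.c[i] = (uint16_t)((i * 7) % kPrime);
  }
  uint8_t encoded[/*kLog2Prime=*/12 * DEGREE / 8];
  scalar_encode(encoded, &s, 12);
  int ok = scalar_decode(&out, encoded, 12);
  assert(ok == 1);
  (void)ok;
  for (int i = 0; i < DEGREE; i++) {
    assert(out.c[i] == s.c[i]);
  }
}
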
// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t in_byte = *in;
    in++;
    for (int j = 0; j < 8; j++) {
      out->c[i + j] = in_byte & 1;
      in_byte >>= 1;
    }
  }
}

// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
// success or zero if any parsed value is >= |kPrime|.
static int vector_decode(vector *out, const uint8_t *in, int bits) {
  for (int i = 0; i < RANK; i++) {
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
      return 0;
    }
  }
  return 1;
}

// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  return quotient & ((1 << bits) - 1);
}

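// Editorial worked example (not in the original): with |bits| = 4 (i.e. kDV),
// compress maps x to round(16 * x / 3329) mod 16. For x = 1664, shifted =
// 26624 = 7 * 3329 + 3321; since kHalfPrime < 3321 <= kPrime + kHalfPrime,
// the quotient is bumped up exactly once, giving round(7.997) = 8. For
// x = 3328, the rounded quotient is 16, which the final mask wraps to 0.
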
// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of the numbers mod |power|
  // have a 0 as their top bit, and the second half have a 1 as their top bit,
  // since |power| is a power of 2. |remainder| is always non-negative, so the
  // right shift shifts in 0s.
  return lower + (remainder >> (bits - 1));
}

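// Editorial worked example (not in the original): with |bits| = 1, decompress
// maps 0 to 0 and 1 to round(3329 / 2) = 1665: product = 3329, lower =
// 3329 >> 1 = 1664, remainder = 1, and the remainder's top bit rounds the
// result up. This is how a message bit becomes the centered value that
// |encrypt_cpa| adds into |v|.
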
static void scalar_compress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = compress(s->c[i], bits);
  }
}

static void scalar_decompress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = decompress(s->c[i], bits);
  }
}

static void vector_compress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_compress(&a->v[i], bits);
  }
}

static void vector_decompress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_decompress(&a->v[i], bits);
  }
}

struct public_key {
  vector t;
  uint8_t rho[32];
  uint8_t public_key_hash[32];
  matrix m;
};

static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}

struct private_key {
  struct public_key pub;
  vector s;
  uint8_t fo_failure_secret[32];
};

static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}

// Calls |KYBER_generate_key_external_entropy| with random bytes from
// |RAND_bytes|.
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
                        struct KYBER_private_key *out_private_key) {
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
  RAND_bytes(entropy, sizeof(entropy));
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
                                      entropy);
}

static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *vector_output;
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(vector_output, &pub->t, kLog2Prime);
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  return 1;
}

// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  uint8_t hashed[64];
  hash_g(hashed, entropy, 32);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    abort();
  }

  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
         KYBER_PUBLIC_KEY_BYTES);
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}

void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
                               const struct KYBER_private_key *private_key) {
  struct public_key *const pub = public_key_from_external(out_public_key);
  const struct private_key *const priv =
      private_key_from_external(private_key);
  *pub = priv->pub;
}

// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform
// this would not result in a CCA secure scheme, since lattice schemes are
// vulnerable to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}

// Calls |KYBER_encap_external_entropy| with random bytes from |RAND_bytes|.
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
                 uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const struct KYBER_public_key *public_key) {
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
                               entropy);
}

// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
    const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, input, sizeof(input));
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  hash_h(prekey_and_randomness + 32, out_ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, prekey_and_randomness,
      sizeof(prekey_and_randomness));
}

// Algorithm 6 of the Kyber spec.
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
  vector u;
  vector_decode(&u, ciphertext, kDU);
  vector_decompress(&u, kDU);
  vector_ntt(&u);
  scalar v;
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
  scalar_decompress(&v, kDV);
  scalar mask;
  scalar_inner_product(&mask, &priv->s, &u);
  scalar_inverse_ntt(&mask);
  scalar_sub(&v, &mask);
  scalar_compress(&v, 1);
  scalar_encode_1(out, &v);
}

// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, decrypted, sizeof(decrypted));
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  hash_h(input + 32, ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, input, sizeof(input));
}

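// Editorial usage sketch (not part of the library, never called): a full KEM
// round trip using the public API defined in this file. Names here are for
// exposition only.
static void example_kem_round_trip(void) {
  uint8_t encoded_public_key[KYBER_PUBLIC_KEY_BYTES];
  struct KYBER_private_key priv;
  KYBER_generate_key(encoded_public_key, &priv);

  // The encapsulating side would typically parse |encoded_public_key| with
  // |KYBER_parse_public_key|; here the public key is taken directly from the
  // private key.
  struct KYBER_public_key pub;
  KYBER_public_from_private(&pub, &priv);

  uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES];
  uint8_t shared_secret1[KYBER_SHARED_SECRET_BYTES];
  KYBER_encap(ciphertext, shared_secret1, &pub);

  uint8_t shared_secret2[KYBER_SHARED_SECRET_BYTES];
  KYBER_decap(shared_secret2, ciphertext, &priv);

  // Both sides now hold the same 32-byte secret.
  assert(CRYPTO_memcmp(shared_secret1, shared_secret2,
                       KYBER_SHARED_SECRET_BYTES) == 0);
}
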
int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
}

// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
// the value of |pub->public_key_hash|.
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
  CBS t_bytes;
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  matrix_expand(&pub->m, pub->rho);
  return 1;
}

799
KYBER_parse_public_key(struct KYBER_public_key * public_key,CBS * in)800 int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
801 struct public_key *pub = public_key_from_external(public_key);
802 CBS orig_in = *in;
803 if (!kyber_parse_public_key_no_hash(pub, in) || //
804 CBS_len(in) != 0) {
805 return 0;
806 }
807 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
808 return 1;
809 }
810
KYBER_marshal_private_key(CBB * out,const struct KYBER_private_key * private_key)811 int KYBER_marshal_private_key(CBB *out,
812 const struct KYBER_private_key *private_key) {
813 const struct private_key *const priv = private_key_from_external(private_key);
814 uint8_t *s_output;
815 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
816 return 0;
817 }
818 vector_encode(s_output, &priv->s, kLog2Prime);
819 if (!kyber_marshal_public_key(out, &priv->pub) ||
820 !CBB_add_bytes(out, priv->pub.public_key_hash,
821 sizeof(priv->pub.public_key_hash)) ||
822 !CBB_add_bytes(out, priv->fo_failure_secret,
823 sizeof(priv->fo_failure_secret))) {
824 return 0;
825 }
826 return 1;
827 }
828
KYBER_parse_private_key(struct KYBER_private_key * out_private_key,CBS * in)829 int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
830 CBS *in) {
831 struct private_key *const priv = private_key_from_external(out_private_key);
832
833 CBS s_bytes;
834 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
835 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
836 !kyber_parse_public_key_no_hash(&priv->pub, in) ||
837 !CBS_copy_bytes(in, priv->pub.public_key_hash,
838 sizeof(priv->pub.public_key_hash)) ||
839 !CBS_copy_bytes(in, priv->fo_failure_secret,
840 sizeof(priv->fo_failure_secret)) ||
841 CBS_len(in) != 0) {
842 return 0;
843 }
844 return 1;
845 }
846