1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/shared/map_of_masks.h"
18
19 #include <array>
20 #include <cstdint>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "gmock/gmock.h"
27 #include "gtest/gtest.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/node_hash_map.h"
30 #include "absl/numeric/bits.h"
31 #include "absl/strings/str_cat.h"
32 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
33 #include "fcp/secagg/shared/input_vector_specification.h"
34 #include "fcp/secagg/shared/math.h"
35 #include "fcp/secagg/shared/secagg_vector.h"
36
37 namespace fcp {
38 namespace secagg {
39 namespace {
40
41 using ::testing::Eq;
42 using ::testing::Lt;
43 using ::testing::Ne;
44
// Twenty non-power-of-two moduli in ascending order, used by the
// *_ArbitraryModuli tests in this file. Declared constexpr so the table is
// guaranteed to be constant-initialized and can be used in constant
// expressions.
constexpr std::array<uint64_t, 20> kArbitraryModuli{5,
                                                    39,
                                                    485,
                                                    2400,
                                                    14901,
                                                    51813,
                                                    532021,
                                                    13916946,
                                                    39549497,
                                                    548811945,
                                                    590549014,
                                                    48296031686,
                                                    156712951284,
                                                    2636861836189,
                                                    14673852658160,
                                                    92971495438615,
                                                    304436005557271,
                                                    14046234330484262,
                                                    38067457113486645,
                                                    175631339105057682};
65
// Adds two single-entry maps whose vectors use the power-of-two modulus 256
// and checks every coordinate of the result against (a + b) mod 256.
TEST(AddMapsTest, AddMapsGetsRightSum_PowerOfTwo) {
  constexpr uint64_t kModulus = 256;
  std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
  std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
  SecAggVectorMap map_a;
  map_a.emplace("test", SecAggVector(vec_a, kModulus));
  SecAggVectorMap map_b;
  map_b.emplace("test", SecAggVector(vec_b, kModulus));

  auto map_sum = AddMaps(map_a, map_b);
  std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
  // Stop on a length mismatch so the loop below never indexes past the end of
  // vec_sum.
  ASSERT_THAT(vec_sum.size(), Eq(vec_a.size()));
  // size_t index: avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < vec_a.size(); ++i) {
    EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % kModulus));
  }
}
80
// Adds two single-entry maps whose vectors use the non-power-of-two modulus
// 255 and checks every coordinate of the result against (a + b) mod 255.
// NOTE(review): the test name is misspelled ("Aribrary"); it is kept unchanged
// so existing test filters continue to match.
TEST(AddMapsTest, AddMapsGetsRightSum_AribraryModuli) {
  constexpr uint64_t kModulus = 255;
  std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
  std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
  SecAggVectorMap map_a;
  map_a.emplace("test", SecAggVector(vec_a, kModulus));
  SecAggVectorMap map_b;
  map_b.emplace("test", SecAggVector(vec_b, kModulus));

  auto map_sum = AddMaps(map_a, map_b);
  std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
  // Stop on a length mismatch so the loop below never indexes past the end of
  // vec_sum.
  ASSERT_THAT(vec_sum.size(), Eq(vec_a.size()));
  // size_t index: avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < vec_a.size(); ++i) {
    EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % kModulus));
  }
}
95
// For every power-of-two modulus 2^k, with k running from 1 up to the bit
// width of SecAggVector::kMaxModulus - 1, adds two pseudo-random vectors with
// AddMaps and checks each coordinate of the sum.
TEST(AddMapsTest, AddMapsExhaustiveTest_PowerOfTwo) {
  // A PRNG built from a fixed seed serves as a consistent source of
  // pseudo-random inputs.
  uint8_t seed_bytes[32];
  memset(seed_bytes, '1', sizeof(seed_bytes));
  AesKey seed_key(seed_bytes);
  AesCtrPrngFactory factory;
  std::unique_ptr<SecurePrng> rng = factory.MakePrng(seed_key);

  constexpr size_t kVectorLength = 1000;
  const int max_bits = absl::bit_width(SecAggVector::kMaxModulus - 1);
  for (int bits = 1; bits <= max_bits; ++bits) {
    const uint64_t modulus = 1ULL << bits;

    // Draw the two operands in lock step: lhs[k] first, then rhs[k].
    std::vector<uint64_t> lhs(kVectorLength);
    std::vector<uint64_t> rhs(kVectorLength);
    for (size_t k = 0; k < kVectorLength; ++k) {
      lhs[k] = rng->Rand64() % modulus;
      rhs[k] = rng->Rand64() % modulus;
    }

    SecAggVectorMap lhs_map;
    SecAggVectorMap rhs_map;
    lhs_map.emplace("test", SecAggVector(lhs, modulus));
    rhs_map.emplace("test", SecAggVector(rhs, modulus));

    auto sum_map = AddMaps(lhs_map, rhs_map);
    std::vector<uint64_t> sum = sum_map->at("test").GetAsUint64Vector();
    for (size_t k = 0; k < kVectorLength; ++k) {
      const uint64_t expected = (lhs[k] + rhs[k]) % modulus;
      EXPECT_THAT(sum[k], Eq(expected));
    }
  }
}
131
// For every modulus in kArbitraryModuli, adds two pseudo-random vectors with
// AddMaps and checks each coordinate of the sum.
TEST(AddMapsTest, AddMapsExhaustiveTest_ArbitraryModuli) {
  // A PRNG built from a fixed seed serves as a consistent source of
  // pseudo-random inputs.
  uint8_t seed_bytes[32];
  memset(seed_bytes, '1', sizeof(seed_bytes));
  AesKey seed_key(seed_bytes);
  AesCtrPrngFactory factory;
  std::unique_ptr<SecurePrng> rng = factory.MakePrng(seed_key);

  constexpr size_t kVectorLength = 1000;
  // Walk the table of arbitrary moduli, add two random vectors for each, and
  // verify the results.
  for (uint64_t modulus : kArbitraryModuli) {
    // Draw the two operands in lock step: lhs[k] first, then rhs[k].
    std::vector<uint64_t> lhs(kVectorLength);
    std::vector<uint64_t> rhs(kVectorLength);
    for (size_t k = 0; k < kVectorLength; ++k) {
      lhs[k] = rng->Rand64() % modulus;
      rhs[k] = rng->Rand64() % modulus;
    }

    SecAggVectorMap lhs_map;
    SecAggVectorMap rhs_map;
    lhs_map.emplace("test", SecAggVector(lhs, modulus));
    rhs_map.emplace("test", SecAggVector(rhs, modulus));

    auto sum_map = AddMaps(lhs_map, rhs_map);
    std::vector<uint64_t> sum = sum_map->at("test").GetAsUint64Vector();
    for (size_t k = 0; k < kVectorLength; ++k) {
      const uint64_t expected = (lhs[k] + rhs[k]) % modulus;
      EXPECT_THAT(sum[k], Eq(expected));
    }
  }
}
164
// Selects which mask-generation function a parameterized MapOfMasksTest case
// exercises.
enum MapOfMasksVersion {
  CURRENT,   // fcp::secagg::MapOfMasks
  V3,        // fcp::secagg::MapOfMasksV3
  UNPACKED,  // fcp::secagg::UnpackedMapOfMasks
};
166
167 class MapOfMasksTest : public ::testing::TestWithParam<MapOfMasksVersion> {
168 public:
169 using Uint64VectorMap =
170 absl::node_hash_map<std::string, std::vector<uint64_t>>;
171
MapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory)172 std::unique_ptr<Uint64VectorMap> MapOfMasks(
173 const std::vector<AesKey>& prng_keys_to_add,
174 const std::vector<AesKey>& prng_keys_to_subtract,
175 const std::vector<InputVectorSpecification>& input_vector_specs,
176 const SessionId& session_id, const AesPrngFactory& prng_factory) {
177 if (GetParam() == MapOfMasksVersion::UNPACKED) {
178 return ToUint64VectorMap(fcp::secagg::UnpackedMapOfMasks(
179 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
180 session_id, prng_factory));
181 } else if (GetParam() == MapOfMasksVersion::V3) {
182 return ToUint64VectorMap(fcp::secagg::MapOfMasksV3(
183 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
184 session_id, prng_factory));
185 } else {
186 return ToUint64VectorMap(fcp::secagg::MapOfMasks(
187 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
188 session_id, prng_factory));
189 }
190 }
191
192 private:
ToUint64VectorMap(std::unique_ptr<SecAggVectorMap> map)193 std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
194 std::unique_ptr<SecAggVectorMap> map) {
195 auto result = std::make_unique<Uint64VectorMap>();
196 for (auto& [name, vec] : *map) {
197 result->emplace(name, vec.GetAsUint64Vector());
198 }
199 return result;
200 }
201
ToUint64VectorMap(std::unique_ptr<SecAggUnpackedVectorMap> map)202 std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
203 std::unique_ptr<SecAggUnpackedVectorMap> map) {
204 auto result = std::make_unique<Uint64VectorMap>();
205 for (auto& [name, vec] : *map) {
206 result->emplace(name, std::move(vec));
207 }
208 return result;
209 }
210 };
211
212 // AES MapOfMasks: Power-of-two Moduli
213
// With no keys to add and none to subtract, the mask for a 10-element vector
// with modulus 2^20 must be all zeroes.
TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_PowerOfTwo) {
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks->size(), Eq(1));
  // Compare against the named expectation rather than a second, identical
  // temporary.
  const std::vector<uint64_t> zeroes(10, 0);
  EXPECT_THAT(masks->at("test"), Eq(zeroes));
}
228
// A single key to add must produce a mask that is not the all-zero vector
// (10 elements, modulus 2^20).
TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_PowerOfTwo) {
  uint8_t key_bytes[AesKey::kSize];
  memset(key_bytes, 'A', sizeof(key_bytes));
  std::vector<AesKey> add_keys;
  add_keys.push_back(AesKey(key_bytes));
  const std::vector<AesKey> subtract_keys;

  const SessionId session{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, 1ULL << 20);

  auto result =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  EXPECT_THAT(result->size(), Eq(1));
  const std::vector<uint64_t> all_zero(10, 0);
  EXPECT_THAT(result->at("test"), Ne(all_zero));
}
245
// The mask built from key A alone must differ from the mask built from keys A
// and B together (10 elements, modulus 2^20).
TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_PowerOfTwo) {
  // The byte buffer is refilled for the second key; per the original note this
  // is safe because AesKey makes a copy.
  uint8_t key_bytes[AesKey::kSize];
  const std::vector<AesKey> subtract_keys;
  const SessionId session{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, 1ULL << 20);

  std::vector<AesKey> add_keys;
  memset(key_bytes, 'A', sizeof(key_bytes));
  add_keys.push_back(AesKey(key_bytes));
  auto one_key_masks =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  memset(key_bytes, 'B', sizeof(key_bytes));
  add_keys.push_back(AesKey(key_bytes));
  auto two_key_masks =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  EXPECT_THAT(one_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->at("test"), Ne(one_key_masks->at("test")));
}
269
// Adding keys {A, B} and subtracting the same keys must give masks that sum to
// zero modulo 2^20 in every coordinate.
TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_PowerOfTwo) {
  constexpr int kLength = 10;
  constexpr uint64_t kModulus = 1ULL << 20;

  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  // Bind by reference; the vectors are only read.
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Stop on a wrong length so the loop below never indexes out of bounds.
  ASSERT_THAT(mask_vector1.size(), Eq(kLength));
  ASSERT_THAT(mask_vector2.size(), Eq(kLength));
  for (int i = 0; i < kLength; ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
296
// A mask that adds key A and subtracts key B must cancel, modulo 2^20, with
// the mask that adds key B and subtracts key A.
TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_PowerOfTwo) {
  constexpr int kLength = 10;
  constexpr uint64_t kModulus = 1ULL << 20;

  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  std::vector<AesKey> prng_keys_to_subtract;
  prng_keys_to_subtract.push_back(AesKey(key));
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  // Bind by reference; the vectors are only read.
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Stop on a wrong length so the loop below never indexes out of bounds.
  ASSERT_THAT(mask_vector1.size(), Eq(kLength));
  ASSERT_THAT(mask_vector2.size(), Eq(kLength));
  for (int i = 0; i < kLength; ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
323
// For several power-of-two moduli, every mask element must be below the
// modulus and at least one element must be >= modulus / 2.
TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_PowerOfTwo) {
  uint8_t key_bytes[AesKey::kSize];
  memset(key_bytes, 'A', sizeof(key_bytes));
  std::vector<AesKey> add_keys;
  add_keys.push_back(AesKey(key_bytes));
  const std::vector<AesKey> subtract_keys;
  const SessionId session{std::string(32, 'Z')};

  // One 50-element vector, named "test<modulus>", per modulus under test.
  const std::vector<uint64_t> moduli = {1ULL << 1, 1ULL << 4, 1ULL << 20,
                                        1ULL << 24, SecAggVector::kMaxModulus};
  std::vector<InputVectorSpecification> specs;
  for (uint64_t modulus : moduli) {
    specs.emplace_back(absl::StrCat("test", modulus), 50, modulus);
  }

  auto result =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  // Each element must stay under its bound, and some element must land in the
  // upper half of the range, i.e. have the highest allowed bit set.
  for (uint64_t modulus : moduli) {
    const auto& mask_values = result->at(absl::StrCat("test", modulus));
    bool saw_upper_half = false;
    for (uint64_t value : mask_values) {
      EXPECT_THAT(value, Lt(modulus));
      saw_upper_half = saw_upper_half || value >= (modulus >> 1);
    }
    EXPECT_THAT(saw_upper_half, Eq(true));
  }
}
358
359 // AES MapOfMasks: Arbitrary Moduli
360
// With no keys to add and none to subtract, the mask for a 10-element vector
// with a non-power-of-two modulus must be all zeroes.
TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_ArbitraryModuli) {
  uint64_t modulus = 2636861836189;
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", 10, modulus));

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks->size(), Eq(1));
  // Compare against the named expectation rather than a second, identical
  // temporary.
  const std::vector<uint64_t> zeroes(10, 0);
  EXPECT_THAT(masks->at("test"), Eq(zeroes));
}
376
// A single key to add must produce a mask that is not the all-zero vector
// (10 elements, non-power-of-two modulus).
TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_ArbitraryModuli) {
  const uint64_t modulus = 2636861836189;

  uint8_t key_bytes[AesKey::kSize];
  memset(key_bytes, 'A', sizeof(key_bytes));
  std::vector<AesKey> add_keys;
  add_keys.push_back(AesKey(key_bytes));
  const std::vector<AesKey> subtract_keys;

  const SessionId session{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, modulus);

  auto result =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  EXPECT_THAT(result->size(), Eq(1));
  const std::vector<uint64_t> all_zero(10, 0);
  EXPECT_THAT(result->at("test"), Ne(all_zero));
}
394
// The mask built from key A alone must differ from the mask built from keys A
// and B together (10 elements, non-power-of-two modulus).
TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_ArbitraryModuli) {
  const uint64_t modulus = 2636861836189;

  // The byte buffer is refilled for the second key; per the original note this
  // is safe because AesKey makes a copy.
  uint8_t key_bytes[AesKey::kSize];
  const std::vector<AesKey> subtract_keys;
  const SessionId session{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, modulus);

  std::vector<AesKey> add_keys;
  memset(key_bytes, 'A', sizeof(key_bytes));
  add_keys.push_back(AesKey(key_bytes));
  auto one_key_masks =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  memset(key_bytes, 'B', sizeof(key_bytes));
  add_keys.push_back(AesKey(key_bytes));
  auto two_key_masks =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  EXPECT_THAT(one_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->at("test"), Ne(one_key_masks->at("test")));
}
419
// Two calls with identical arguments (add key A, subtract key B) must return
// identical masks.
TEST_P(MapOfMasksTest, MapsAreDeterministic_KeysToAdd_ArbitraryModuli) {
  constexpr int kLength = 10;
  uint64_t modulus = 2636861836189;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));

  // prng_keys_to_subtract includes B
  std::vector<AesKey> prng_keys_to_subtract;
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_subtract.push_back(AesKey(key));

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, modulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  auto masks2 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  // Bind by reference; the vectors are only read.
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Check the length explicitly, then compare the vectors as a whole instead
  // of indexing the first ten elements without a bounds check.
  EXPECT_THAT(mask_vector1.size(), Eq(kLength));
  EXPECT_THAT(mask_vector1, Eq(mask_vector2));
}
451
// Adding keys {A, B} and subtracting the same keys must give masks that sum to
// zero modulo a non-power-of-two modulus in every coordinate.
TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_ArbitraryModuli) {
  constexpr int kLength = 10;
  uint64_t modulus = 2636861836189;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A & B
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  // prng_keys_to_subtract is empty
  std::vector<AesKey> prng_keys_to_subtract;

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, modulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  // Bind by reference; the vectors are only read.
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Stop on a wrong length so the loop below never indexes out of bounds.
  ASSERT_THAT(mask_vector1.size(), Eq(kLength));
  ASSERT_THAT(mask_vector2.size(), Eq(kLength));
  for (int i = 0; i < kLength; ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], modulus), Eq(0));
  }
}
482
// A mask that adds key A and subtracts key B must cancel, modulo a
// non-power-of-two modulus, with the mask that adds key B and subtracts key A.
TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_ArbitraryModuli) {
  constexpr int kLength = 10;
  uint64_t modulus = 2636861836189;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  // prng_keys_to_subtract includes B
  std::vector<AesKey> prng_keys_to_subtract;
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_subtract.push_back(AesKey(key));

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, modulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  // Bind by reference; the vectors are only read.
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Stop on a wrong length so the loop below never indexes out of bounds.
  ASSERT_THAT(mask_vector1.size(), Eq(kLength));
  ASSERT_THAT(mask_vector2.size(), Eq(kLength));
  for (int i = 0; i < kLength; ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], modulus), Eq(0));
  }
}
513
// For every modulus in kArbitraryModuli, every mask element must be below the
// modulus and at least one element must be >= modulus / 2.
TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_ArbitraryModuli) {
  uint8_t key_bytes[AesKey::kSize];
  memset(key_bytes, 'A', sizeof(key_bytes));
  std::vector<AesKey> add_keys;
  add_keys.push_back(AesKey(key_bytes));
  const std::vector<AesKey> subtract_keys;
  const SessionId session{std::string(32, 'Z')};

  // One 50-element vector, named "test<modulus>", per modulus under test.
  std::vector<InputVectorSpecification> specs;
  for (uint64_t modulus : kArbitraryModuli) {
    specs.emplace_back(absl::StrCat("test", modulus), 50, modulus);
  }

  auto result =
      MapOfMasks(add_keys, subtract_keys, specs, session, AesCtrPrngFactory());

  // Each element must stay under its bound, and some element must land in the
  // upper half of the range [0, modulus).
  for (uint64_t modulus : kArbitraryModuli) {
    const auto& mask_values = result->at(absl::StrCat("test", modulus));
    bool saw_upper_half = false;
    for (uint64_t value : mask_values) {
      EXPECT_THAT(value, Lt(modulus));
      saw_upper_half = saw_upper_half || value >= (modulus >> 1);
    }
    EXPECT_THAT(saw_upper_half, Eq(true));
  }
}
546
547 INSTANTIATE_TEST_SUITE_P(MapOfMasksTest, MapOfMasksTest,
548 ::testing::Values<MapOfMasksVersion>(CURRENT, V3,
549 UNPACKED));
550
551 } // namespace
552 } // namespace secagg
553 } // namespace fcp
554