xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/map_of_masks_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fcp/secagg/shared/map_of_masks.h"

#include <array>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/numeric/bits.h"
#include "absl/strings/str_cat.h"
#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/math.h"
#include "fcp/secagg/shared/secagg_vector.h"
36 
37 namespace fcp {
38 namespace secagg {
39 namespace {
40 
41 using ::testing::Eq;
42 using ::testing::Lt;
43 using ::testing::Ne;
44 
// Moduli that are not powers of two, in strictly increasing order. They range
// from 5 (3 bits wide) up to a 58-bit value, so tests iterating over them
// cover a wide spread of bit widths.
constexpr std::array<uint64_t, 20> kArbitraryModuli{
    5,                 39,                485,
    2400,              14901,             51813,
    532021,            13916946,          39549497,
    548811945,         590549014,         48296031686,
    156712951284,      2636861836189,     14673852658160,
    92971495438615,    304436005557271,   14046234330484262,
    38067457113486645, 175631339105057682};
65 
// Adds two single-entry maps whose vectors share the power-of-two modulus 256
// and checks every element of the sum against the element-wise sum mod 256.
TEST(AddMapsTest, AddMapsGetsRightSum_PowerOfTwo) {
  constexpr uint64_t kModulus = 256;
  std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
  std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
  SecAggVectorMap map_a;
  map_a.emplace("test", SecAggVector(vec_a, kModulus));
  SecAggVectorMap map_b;
  map_b.emplace("test", SecAggVector(vec_b, kModulus));

  auto map_sum = AddMaps(map_a, map_b);
  std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
  // Check the length first so a short result fails cleanly instead of being
  // read out of bounds by the indexed loop below.
  ASSERT_THAT(vec_sum.size(), Eq(vec_a.size()));
  for (size_t i = 0; i < vec_a.size(); ++i) {
    EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % kModulus));
  }
}
80 
// Adds two single-entry maps whose vectors share the non-power-of-two modulus
// 255 and checks every element of the sum against the element-wise sum
// mod 255.
// NOTE: the "Aribrary" spelling in the test name is kept as-is so existing
// test filters and result histories keep matching.
TEST(AddMapsTest, AddMapsGetsRightSum_AribraryModuli) {
  constexpr uint64_t kModulus = 255;
  std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
  std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
  SecAggVectorMap map_a;
  map_a.emplace("test", SecAggVector(vec_a, kModulus));
  SecAggVectorMap map_b;
  map_b.emplace("test", SecAggVector(vec_b, kModulus));

  auto map_sum = AddMaps(map_a, map_b);
  std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
  // Check the length first so a short result fails cleanly instead of being
  // read out of bounds by the indexed loop below.
  ASSERT_THAT(vec_sum.size(), Eq(vec_a.size()));
  for (size_t i = 0; i < vec_a.size(); ++i) {
    EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % kModulus));
  }
}
95 
// For every power-of-two modulus from 2^1 up to
// 2^bit_width(SecAggVector::kMaxModulus - 1), adds two pseudo-random vectors
// through AddMaps and checks each element against the expected sum.
TEST(AddMapsTest, AddMapsExhaustiveTest_PowerOfTwo) {
  // Make SecurePrng instance to be used as a consistent pseudo-random number
  // generator.
  uint8_t seed_data[32];
  memset(seed_data, '1', sizeof(seed_data));
  AesKey seed(seed_data);
  AesCtrPrngFactory prng_factory;
  std::unique_ptr<SecurePrng> prng = prng_factory.MakePrng(seed);

  // Iterate through all possible bitwidths, add two random vectors, and
  // verify the results.
  constexpr size_t kSize = 1000;
  for (int number_of_bits = 1;
       number_of_bits <= absl::bit_width(SecAggVector::kMaxModulus - 1);
       ++number_of_bits) {
    const uint64_t modulus = 1ULL << number_of_bits;
    // Tag every failure below with the modulus that produced it.
    SCOPED_TRACE(absl::StrCat("modulus = ", modulus));
    std::vector<uint64_t> vec_a(kSize);
    std::vector<uint64_t> vec_b(kSize);
    for (size_t i = 0; i < kSize; i++) {
      vec_a[i] = prng->Rand64() % modulus;
      vec_b[i] = prng->Rand64() % modulus;
    }

    SecAggVectorMap map_a;
    map_a.emplace("test", SecAggVector(vec_a, modulus));
    SecAggVectorMap map_b;
    map_b.emplace("test", SecAggVector(vec_b, modulus));

    auto map_sum = AddMaps(map_a, map_b);
    std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
    // Guard the indexed comparison loop against a short result.
    ASSERT_THAT(vec_sum.size(), Eq(kSize));
    for (size_t i = 0; i < kSize; i++) {
      EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % modulus));
    }
  }
}
131 
// For each modulus in kArbitraryModuli, adds two pseudo-random vectors through
// AddMaps and checks each element against the expected sum modulo that
// modulus.
TEST(AddMapsTest, AddMapsExhaustiveTest_ArbitraryModuli) {
  // Make SecurePrng instance to be used as a consistent pseudo-random number
  // generator.
  uint8_t seed_data[32];
  memset(seed_data, '1', sizeof(seed_data));
  AesKey seed(seed_data);
  AesCtrPrngFactory prng_factory;
  std::unique_ptr<SecurePrng> prng = prng_factory.MakePrng(seed);

  // Iterate through every modulus in kArbitraryModuli, add two random vectors,
  // and verify the results.
  constexpr size_t kSize = 1000;
  for (uint64_t modulus : kArbitraryModuli) {
    // Tag every failure below with the modulus that produced it.
    SCOPED_TRACE(absl::StrCat("modulus = ", modulus));
    std::vector<uint64_t> vec_a(kSize);
    std::vector<uint64_t> vec_b(kSize);
    for (size_t i = 0; i < kSize; i++) {
      vec_a[i] = prng->Rand64() % modulus;
      vec_b[i] = prng->Rand64() % modulus;
    }

    SecAggVectorMap map_a;
    map_a.emplace("test", SecAggVector(vec_a, modulus));
    SecAggVectorMap map_b;
    map_b.emplace("test", SecAggVector(vec_b, modulus));

    auto map_sum = AddMaps(map_a, map_b);
    std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
    // Guard the indexed comparison loop against a short result.
    ASSERT_THAT(vec_sum.size(), Eq(kSize));
    for (size_t i = 0; i < kSize; i++) {
      EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % modulus));
    }
  }
}
164 
// Selects which mask-generation function a MapOfMasksTest instance exercises:
// CURRENT -> MapOfMasks, V3 -> MapOfMasksV3, UNPACKED -> UnpackedMapOfMasks.
enum MapOfMasksVersion {
  CURRENT,
  V3,
  UNPACKED,
};
166 
167 class MapOfMasksTest : public ::testing::TestWithParam<MapOfMasksVersion> {
168  public:
169   using Uint64VectorMap =
170       absl::node_hash_map<std::string, std::vector<uint64_t>>;
171 
MapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory)172   std::unique_ptr<Uint64VectorMap> MapOfMasks(
173       const std::vector<AesKey>& prng_keys_to_add,
174       const std::vector<AesKey>& prng_keys_to_subtract,
175       const std::vector<InputVectorSpecification>& input_vector_specs,
176       const SessionId& session_id, const AesPrngFactory& prng_factory) {
177     if (GetParam() == MapOfMasksVersion::UNPACKED) {
178       return ToUint64VectorMap(fcp::secagg::UnpackedMapOfMasks(
179           prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
180           session_id, prng_factory));
181     } else if (GetParam() == MapOfMasksVersion::V3) {
182       return ToUint64VectorMap(fcp::secagg::MapOfMasksV3(
183           prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
184           session_id, prng_factory));
185     } else {
186       return ToUint64VectorMap(fcp::secagg::MapOfMasks(
187           prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
188           session_id, prng_factory));
189     }
190   }
191 
192  private:
ToUint64VectorMap(std::unique_ptr<SecAggVectorMap> map)193   std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
194       std::unique_ptr<SecAggVectorMap> map) {
195     auto result = std::make_unique<Uint64VectorMap>();
196     for (auto& [name, vec] : *map) {
197       result->emplace(name, vec.GetAsUint64Vector());
198     }
199     return result;
200   }
201 
ToUint64VectorMap(std::unique_ptr<SecAggUnpackedVectorMap> map)202   std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
203       std::unique_ptr<SecAggUnpackedVectorMap> map) {
204     auto result = std::make_unique<Uint64VectorMap>();
205     for (auto& [name, vec] : *map) {
206       result->emplace(name, std::move(vec));
207     }
208     return result;
209   }
210 };
211 
212 // AES MapOfMasks: Power-of-two Moduli
213 
// With no keys to add or subtract, the mask for the single requested vector
// (length 10, modulus 2^20) must be all zeroes.
TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_PowerOfTwo) {
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks->size(), Eq(1));
  const std::vector<uint64_t> zeroes(10, 0);
  EXPECT_THAT(masks->at("test"), Eq(zeroes));
}
228 
// One key to add and none to subtract must produce a mask that is not
// identically zero (length 10, modulus 2^20).
TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_PowerOfTwo) {
  uint8_t raw_key[AesKey::kSize];
  memset(raw_key, 'A', AesKey::kSize);
  std::vector<AesKey> keys_to_add;
  keys_to_add.push_back(AesKey(raw_key));
  const std::vector<AesKey> keys_to_subtract;

  const SessionId session_id{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, 1ULL << 20);

  const auto masks = MapOfMasks(keys_to_add, keys_to_subtract, specs,
                                session_id, AesCtrPrngFactory());

  const std::vector<uint64_t> all_zero(10, 0);
  EXPECT_THAT(masks->size(), Eq(1));
  EXPECT_THAT(masks->at("test"), Ne(all_zero));
}
245 
// Adding a second key to add must change the mask compared with the result
// for the first key alone (length 10, modulus 2^20).
TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_PowerOfTwo) {
  const SessionId session_id{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, 1ULL << 20);
  const std::vector<AesKey> no_keys;

  // The raw buffer may be overwritten between uses because AesKey keeps its
  // own copy of the bytes.
  uint8_t raw_key[AesKey::kSize];
  std::vector<AesKey> keys_to_add;

  memset(raw_key, 'A', AesKey::kSize);
  keys_to_add.push_back(AesKey(raw_key));
  auto one_key_masks = MapOfMasks(keys_to_add, no_keys, specs, session_id,
                                  AesCtrPrngFactory());

  memset(raw_key, 'B', AesKey::kSize);
  keys_to_add.push_back(AesKey(raw_key));
  auto two_key_masks = MapOfMasks(keys_to_add, no_keys, specs, session_id,
                                  AesCtrPrngFactory());

  EXPECT_THAT(one_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->at("test"), Ne(one_key_masks->at("test")));
}
269 
// Masks built with keys A and B as keys to add must cancel element-wise
// (mod 2^20) against masks built with the same keys as keys to subtract.
TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_PowerOfTwo) {
  constexpr uint64_t kModulus = 1ULL << 20;
  constexpr int kLength = 10;
  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Check lengths before indexing so a short vector fails cleanly instead of
  // being read out of bounds.
  ASSERT_THAT(mask_vector1, ::testing::SizeIs(kLength));
  ASSERT_THAT(mask_vector2, ::testing::SizeIs(kLength));
  for (size_t i = 0; i < mask_vector1.size(); ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
296 
// A mask built with A added and B subtracted must cancel element-wise
// (mod 2^20) against one built with B added and A subtracted.
TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_PowerOfTwo) {
  constexpr uint64_t kModulus = 1ULL << 20;
  constexpr int kLength = 10;
  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  std::vector<AesKey> prng_keys_to_subtract;
  prng_keys_to_subtract.push_back(AesKey(key));
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Check lengths before indexing so a short vector fails cleanly instead of
  // being read out of bounds.
  ASSERT_THAT(mask_vector1, ::testing::SizeIs(kLength));
  ASSERT_THAT(mask_vector2, ::testing::SizeIs(kLength));
  for (size_t i = 0; i < mask_vector1.size(); ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
323 
// For several power-of-two moduli up to SecAggVector::kMaxModulus, every mask
// element must be below the modulus. At least one of the 50 elements per
// vector must have the highest allowed bit set (be >= modulus / 2); that check
// fails if masks were drawn from too few bits.
TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_PowerOfTwo) {
  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;

  // Check a variety of bit widths.
  const std::vector<uint64_t> moduli{1ULL << 1, 1ULL << 4, 1ULL << 20,
                                     1ULL << 24, SecAggVector::kMaxModulus};
  for (uint64_t modulus : moduli) {
    vector_specs.push_back(
        InputVectorSpecification(absl::StrCat("test", modulus), 50, modulus));
  }

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());
  EXPECT_THAT(masks->size(), Eq(moduli.size()));

  for (uint64_t modulus : moduli) {
    SCOPED_TRACE(absl::StrCat("modulus = ", modulus));
    // Bind by reference; the map owns the vector, so no copy is needed.
    const std::vector<uint64_t>& vec =
        masks->at(absl::StrCat("test", modulus));
    bool high_order_bit_set = false;
    for (uint64_t mask : vec) {
      EXPECT_THAT(mask, Lt(modulus));
      if (mask >= (modulus >> 1)) {
        high_order_bit_set = true;
      }
    }
    EXPECT_TRUE(high_order_bit_set);
  }
}
358 
359 // AES MapOfMasks: Arbitrary Moduli
360 
// With no keys to add or subtract, the mask for the single requested vector
// (length 10, non-power-of-two modulus) must be all zeroes.
TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_ArbitraryModuli) {
  uint64_t modulus = 2636861836189;
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", 10, modulus));

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks->size(), Eq(1));
  const std::vector<uint64_t> zeroes(10, 0);
  EXPECT_THAT(masks->at("test"), Eq(zeroes));
}
376 
// One key to add and none to subtract must produce a mask that is not
// identically zero (length 10, non-power-of-two modulus).
TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_ArbitraryModuli) {
  constexpr uint64_t kModulus = 2636861836189;

  uint8_t raw_key[AesKey::kSize];
  memset(raw_key, 'A', AesKey::kSize);
  std::vector<AesKey> keys_to_add;
  keys_to_add.push_back(AesKey(raw_key));
  const std::vector<AesKey> keys_to_subtract;

  const SessionId session_id{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, kModulus);

  const auto masks = MapOfMasks(keys_to_add, keys_to_subtract, specs,
                                session_id, AesCtrPrngFactory());

  const std::vector<uint64_t> all_zero(10, 0);
  EXPECT_THAT(masks->size(), Eq(1));
  EXPECT_THAT(masks->at("test"), Ne(all_zero));
}
394 
// Adding a second key to add must change the mask compared with the result
// for the first key alone (length 10, non-power-of-two modulus).
TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_ArbitraryModuli) {
  constexpr uint64_t kModulus = 2636861836189;
  const SessionId session_id{std::string(32, 'Z')};
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 10, kModulus);
  const std::vector<AesKey> no_keys;

  // The raw buffer may be overwritten between uses because AesKey keeps its
  // own copy of the bytes.
  uint8_t raw_key[AesKey::kSize];
  std::vector<AesKey> keys_to_add;

  memset(raw_key, 'A', AesKey::kSize);
  keys_to_add.push_back(AesKey(raw_key));
  auto one_key_masks = MapOfMasks(keys_to_add, no_keys, specs, session_id,
                                  AesCtrPrngFactory());

  memset(raw_key, 'B', AesKey::kSize);
  keys_to_add.push_back(AesKey(raw_key));
  auto two_key_masks = MapOfMasks(keys_to_add, no_keys, specs, session_id,
                                  AesCtrPrngFactory());

  EXPECT_THAT(one_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->size(), Eq(1));
  EXPECT_THAT(two_key_masks->at("test"), Ne(one_key_masks->at("test")));
}
419 
// Two calls with identical keys, vector specs and session id must return
// identical masks.
TEST_P(MapOfMasksTest, MapsAreDeterministic_KeysToAdd_ArbitraryModuli) {
  uint64_t modulus = 2636861836189;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));

  // prng_keys_to_subtract includes B
  std::vector<AesKey> prng_keys_to_subtract;
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_subtract.push_back(AesKey(key));

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", 10, modulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  auto masks2 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  EXPECT_THAT(masks1->at("test"), ::testing::SizeIs(10));
  // Compare whole vectors. This also checks that both lengths agree, and it
  // avoids the out-of-bounds reads a fixed-length indexed loop could make.
  EXPECT_THAT(masks1->at("test"), Eq(masks2->at("test")));
}
451 
// Masks built with keys A and B as keys to add must cancel element-wise
// (mod a non-power-of-two modulus) against masks built with the same keys as
// keys to subtract.
TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_ArbitraryModuli) {
  constexpr uint64_t kModulus = 2636861836189;
  constexpr int kLength = 10;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A & B
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  // prng_keys_to_subtract is empty
  std::vector<AesKey> prng_keys_to_subtract;

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Check lengths before indexing so a short vector fails cleanly instead of
  // being read out of bounds.
  ASSERT_THAT(mask_vector1, ::testing::SizeIs(kLength));
  ASSERT_THAT(mask_vector2, ::testing::SizeIs(kLength));
  for (size_t i = 0; i < mask_vector1.size(); ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
482 
// A mask built with A added and B subtracted must cancel element-wise (mod a
// non-power-of-two modulus) against one built with B added and A subtracted.
TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_ArbitraryModuli) {
  constexpr uint64_t kModulus = 2636861836189;
  constexpr int kLength = 10;
  uint8_t key[AesKey::kSize];
  // prng_keys_to_add includes A
  std::vector<AesKey> prng_keys_to_add;
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  // prng_keys_to_subtract includes B
  std::vector<AesKey> prng_keys_to_subtract;
  memset(key, 'B', AesKey::kSize);
  prng_keys_to_subtract.push_back(AesKey(key));

  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;
  vector_specs.push_back(InputVectorSpecification("test", kLength, kModulus));

  auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
                           vector_specs, session_id, AesCtrPrngFactory());

  // Same keys with the add/subtract roles swapped.
  auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
                           vector_specs, session_id, AesCtrPrngFactory());

  EXPECT_THAT(masks1->size(), Eq(1));
  EXPECT_THAT(masks2->size(), Eq(1));
  const std::vector<uint64_t>& mask_vector1 = masks1->at("test");
  const std::vector<uint64_t>& mask_vector2 = masks2->at("test");
  // Check lengths before indexing so a short vector fails cleanly instead of
  // being read out of bounds.
  ASSERT_THAT(mask_vector1, ::testing::SizeIs(kLength));
  ASSERT_THAT(mask_vector2, ::testing::SizeIs(kLength));
  for (size_t i = 0; i < mask_vector1.size(); ++i) {
    EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], kModulus), Eq(0));
  }
}
513 
// For every modulus in kArbitraryModuli, every mask element must be below the
// modulus. At least one of the 50 elements per vector must also fall in the
// upper half of the range (>= modulus >> 1); that check fails if masks were
// drawn from too few bits. For these non-power-of-two moduli, "upper half" is
// the correct reading of the check, not "highest bit set".
TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_ArbitraryModuli) {
  std::vector<AesKey> prng_keys_to_add;
  uint8_t key[AesKey::kSize];
  memset(key, 'A', AesKey::kSize);
  prng_keys_to_add.push_back(AesKey(key));
  std::vector<AesKey> prng_keys_to_subtract;
  SessionId session_id = {std::string(32, 'Z')};
  std::vector<InputVectorSpecification> vector_specs;

  // Check a variety of bit widths.
  for (uint64_t modulus : kArbitraryModuli) {
    vector_specs.push_back(
        InputVectorSpecification(absl::StrCat("test", modulus), 50, modulus));
  }

  auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
                          session_id, AesCtrPrngFactory());
  EXPECT_THAT(masks->size(), Eq(kArbitraryModuli.size()));

  for (uint64_t modulus : kArbitraryModuli) {
    SCOPED_TRACE(absl::StrCat("modulus = ", modulus));
    // Bind by reference; the map owns the vector, so no copy is needed.
    const std::vector<uint64_t>& vec =
        masks->at(absl::StrCat("test", modulus));
    bool upper_half_seen = false;
    for (uint64_t mask : vec) {
      EXPECT_THAT(mask, Lt(modulus));
      if (mask >= (modulus >> 1)) {
        upper_half_seen = true;
      }
    }
    EXPECT_TRUE(upper_half_seen);
  }
}
546 
547 INSTANTIATE_TEST_SUITE_P(MapOfMasksTest, MapOfMasksTest,
548                          ::testing::Values<MapOfMasksVersion>(CURRENT, V3,
549                                                               UNPACKED));
550 
551 }  // namespace
552 }  // namespace secagg
553 }  // namespace fcp
554