1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_prng_running_state.h"
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/container/node_hash_map.h"
29 #include "absl/container/node_hash_set.h"
30 #include "absl/strings/str_cat.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/base/scheduler.h"
33 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
34 #include "fcp/secagg/server/secagg_scheduler.h"
35 #include "fcp/secagg/server/secagg_server_enums.pb.h"
36 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
37 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
38 #include "fcp/secagg/shared/aes_key.h"
39 #include "fcp/secagg/shared/ecdh_key_agreement.h"
40 #include "fcp/secagg/shared/ecdh_keys.h"
41 #include "fcp/secagg/shared/input_vector_specification.h"
42 #include "fcp/secagg/shared/map_of_masks.h"
43 #include "fcp/secagg/shared/secagg_messages.pb.h"
44 #include "fcp/secagg/shared/secagg_vector.h"
45 #include "fcp/secagg/shared/shamir_secret_sharing.h"
46 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
47 #include "fcp/secagg/testing/fake_prng.h"
48 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
49 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
50 #include "fcp/secagg/testing/server/test_async_runner.h"
51 #include "fcp/secagg/testing/test_matchers.h"
52 #include "fcp/testing/testing.h"
53 #include "fcp/tracing/test_tracing_recorder.h"
54 
55 namespace fcp {
56 namespace secagg {
57 namespace {
58 
59 using ::testing::_;
60 using ::testing::Eq;
61 using ::testing::Ge;
62 using ::testing::NiceMock;
63 
64 // For testing purposes, make an AesKey out of a string.
MakeAesKey(const std::string & key)65 AesKey MakeAesKey(const std::string& key) {
66   EXPECT_THAT(key.size(), Eq(AesKey::kSize));
67   return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
68 }
69 
70 class MockScheduler : public Scheduler {
71  public:
72   MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
73   MOCK_METHOD(void, WaitUntilIdle, ());
74 };
75 
__anonfc464b740202(const std::function<void()>& f) 76 constexpr auto call_fn = [](const std::function<void()>& f) { f(); };
77 
78 // Default test session_id.
MakeTestSessionId()79 std::unique_ptr<SessionId> MakeTestSessionId() {
80   SessionId session_id = {"session id number, 32 bytes long"};
81   return std::make_unique<SessionId>(session_id);
82 }
83 
CreateSecAggServerProtocolImpl(std::vector<InputVectorSpecification> input_vector_specs,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)84 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
85     std::vector<InputVectorSpecification> input_vector_specs,
86     MockSendToClientsInterface* sender,
87     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
88   SecretSharingGraphFactory factory;
89   auto parallel_scheduler = std::make_unique<NiceMock<MockScheduler>>();
90   auto sequential_scheduler = std::make_unique<NiceMock<MockScheduler>>();
91   EXPECT_CALL(*parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
92   EXPECT_CALL(*sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);
93   auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
94       factory.CreateCompleteGraph(4, 3),  // total number of clients is 4
95       3,  // minimum_number_of_clients_to_proceed
96       input_vector_specs,
97       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
98       std::make_unique<AesCtrPrngFactory>(), sender,
99       std::make_unique<TestAsyncRunner>(std::move(parallel_scheduler),
100                                         std::move(sequential_scheduler)),
101       std::vector<ClientStatus>(4, ClientStatus::UNMASKING_RESPONSE_RECEIVED),
102       ServerVariant::NATIVE_V1);
103   impl->set_session_id(MakeTestSessionId());
104   EcdhPregeneratedTestKeys ecdh_keys;
105   for (int i = 0; i < 4; ++i) {
106     impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
107   }
108   impl->set_masked_input(std::make_unique<SecAggUnpackedVectorMap>());
109   return impl;
110 }
111 
112 // Mock class containing a callback that would be called when the PRNG is done.
113 class MockPrngDone {
114  public:
115   MOCK_METHOD(void, Callback, ());
116 };
117 
TEST(SecaggServerPrngRunningStateTest,IsAbortedReturnsFalse)118 TEST(SecaggServerPrngRunningStateTest, IsAbortedReturnsFalse) {
119   auto input_vector_specs = std::vector<InputVectorSpecification>();
120   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
121   auto sender = std::make_unique<MockSendToClientsInterface>();
122   FakePrng prng;
123   ShamirSecretSharing sharer;
124   auto self_shamir_share_table = std::make_unique<
125       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
126   for (int i = 0; i < 4; ++i) {
127     self_shamir_share_table->try_emplace(
128         i, sharer.Share(
129                3, 4,
130                MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
131   }
132 
133   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
134   impl->set_pairwise_shamir_share_table(
135       std::make_unique<
136           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
137   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
138 
139   SecAggServerPrngRunningState state(
140       std::move(impl),
141       0,   // number_of_clients_failed_after_sending_masked_input
142       0,   // number_of_clients_failed_before_sending_masked_input
143       0);  // number_of_clients_terminated_without_unmasking
144 
145   EXPECT_THAT(state.IsAborted(), Eq(false));
146 }
147 
TEST(SecaggServerPrngRunningStateTest,IsCompletedSuccessfullyReturnsFalse)148 TEST(SecaggServerPrngRunningStateTest, IsCompletedSuccessfullyReturnsFalse) {
149   auto input_vector_specs = std::vector<InputVectorSpecification>();
150   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
151   auto sender = std::make_unique<MockSendToClientsInterface>();
152   FakePrng prng;
153   ShamirSecretSharing sharer;
154   auto self_shamir_share_table = std::make_unique<
155       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
156   for (int i = 0; i < 4; ++i) {
157     self_shamir_share_table->try_emplace(
158         i, sharer.Share(
159                3, 4,
160                MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
161   }
162 
163   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
164   impl->set_pairwise_shamir_share_table(
165       std::make_unique<
166           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
167   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
168 
169   SecAggServerPrngRunningState state(
170       std::move(impl),
171       0,   // number_of_clients_failed_after_sending_masked_input
172       0,   // number_of_clients_failed_before_sending_masked_input
173       0);  // number_of_clients_terminated_without_unmasking
174 
175   EXPECT_THAT(state.IsCompletedSuccessfully(), Eq(false));
176 }
177 
TEST(SecaggServerPrngRunningStateTest,ErrorMessageRaisesError)178 TEST(SecaggServerPrngRunningStateTest, ErrorMessageRaisesError) {
179   auto input_vector_specs = std::vector<InputVectorSpecification>();
180   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
181   auto sender = std::make_unique<MockSendToClientsInterface>();
182   FakePrng prng;
183   ShamirSecretSharing sharer;
184   auto self_shamir_share_table = std::make_unique<
185       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
186   for (int i = 0; i < 4; ++i) {
187     self_shamir_share_table->try_emplace(
188         i, sharer.Share(
189                3, 4,
190                MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
191   }
192 
193   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
194   impl->set_pairwise_shamir_share_table(
195       std::make_unique<
196           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
197   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
198 
199   SecAggServerPrngRunningState state(
200       std::move(impl),
201       0,   // number_of_clients_failed_after_sending_masked_input
202       0,   // number_of_clients_failed_before_sending_masked_input
203       0);  // number_of_clients_terminated_without_unmasking
204 
205   EXPECT_THAT(state.ErrorMessage().ok(), Eq(false));
206 }
207 
// No client messages have been handled yet, so the per-round message counter
// must be zero.
TEST(SecaggServerPrngRunningStateTest,
     NumberOfMessagesReceivedInThisRoundReturnsZero) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // One 3-of-4 sharing of each client's self-mask key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
}
238 
// Right after construction, no client is counted as ready for the next round.
TEST(SecaggServerPrngRunningStateTest,
     NumberOfClientsReadyForNextRoundReturnsZero) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // One 3-of-4 sharing of each client's self-mask key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
}
269 
TEST(SecaggServerPrngRunningStateTest,HandleNonAbortMessageAbortsClientDoesNotRecordMetrics)270 TEST(SecaggServerPrngRunningStateTest,
271      HandleNonAbortMessageAbortsClientDoesNotRecordMetrics) {
272   TestTracingRecorder tracing_recorder;
273   auto input_vector_specs = std::vector<InputVectorSpecification>();
274   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
275   MockSecAggServerMetricsListener* metrics =
276       new MockSecAggServerMetricsListener();
277   auto sender = std::make_unique<MockSendToClientsInterface>();
278   FakePrng prng;
279   ShamirSecretSharing sharer;
280   auto self_shamir_share_table = std::make_unique<
281       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
282   for (int i = 0; i < 4; ++i) {
283     self_shamir_share_table->try_emplace(
284         i, sharer.Share(
285                3, 4,
286                MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
287   }
288 
289   auto impl =
290       CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
291   impl->set_pairwise_shamir_share_table(
292       std::make_unique<
293           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
294   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
295 
296   SecAggServerPrngRunningState state(
297       std::move(impl),
298       0,   // number_of_clients_failed_after_sending_masked_input
299       0,   // number_of_clients_failed_before_sending_masked_input
300       0);  // number_of_clients_terminated_without_unmasking
301 
302   ServerToClientWrapperMessage abort_message;
303   abort_message.mutable_abort()->set_early_success(false);
304   abort_message.mutable_abort()->set_diagnostic_info(
305       "Non-abort message sent during PrngUnmasking step.");
306 
307   ClientToServerWrapperMessage client_message;
308   EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(abort_message)));
309   EXPECT_CALL(*metrics, MessageReceivedSizes(
310                             Eq(ClientToServerWrapperMessage::
311                                    MessageContentCase::MESSAGE_CONTENT_NOT_SET),
312                             Eq(false), Eq(client_message.ByteSizeLong())));
313   EXPECT_CALL(*metrics,
314               IndividualMessageSizes(
315                   Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
316                   Eq(abort_message.ByteSizeLong())));
317   EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
318 
319   EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
320   EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
321   ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
322   EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
323               ElementsAre(IsEvent<IndividualMessageSent>(
324                   Eq(0), Eq(ServerToClientMessageType_Abort),
325                   Eq(abort_message.ByteSizeLong()))));
326   EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
327               ElementsAre(IsEvent<ClientMessageReceived>(
328                   Eq(ClientToServerMessageType_MessageContentNotSet),
329                   Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
330 }
331 
// A client that sends an abort message during the PRNG step is recorded as
// aborted without being sent anything back. The ClientsDropped metric must not
// be recorded.
TEST(SecaggServerPrngRunningStateTest,
     HandleAbortMessageAbortsClientDoesNotRecordMetrics) {
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // Ownership passes to the protocol impl via CreateSecAggServerProtocolImpl.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  ClientToServerWrapperMessage client_message;
  client_message.mutable_abort();
  EXPECT_CALL(*metrics,
              MessageReceivedSizes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
                  Eq(false), Eq(client_message.ByteSizeLong())));
  EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
  EXPECT_CALL(*sender, Send(Eq(0), _)).Times(0);

  EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
  EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
  EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
              ElementsAre(IsEvent<ClientMessageReceived>(
                  Eq(ClientToServerMessageType_Abort),
                  Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
}
381 
// Abort() broadcasts an abort message carrying the given reason, reports the
// outcome to the metrics listener, and yields an ABORTED state whose
// ErrorMessage() is that reason.
TEST(SecaggServerPrngRunningStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // Ownership passes to the protocol impl via CreateSecAggServerProtocolImpl.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::UNHANDLED_ERROR)));
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
  auto next_state =
      state.Abort("test abort reason", SecAggServerOutcome::UNHANDLED_ERROR);

  ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(next_state->ErrorMessage().ok(), Eq(true));
  EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(abort_message.ByteSizeLong()))));
}
432 
TEST(SecaggServerPrngRunningStateTest,PrngGetsRightMasksWhenAllClientsSurvive)433 TEST(SecaggServerPrngRunningStateTest,
434      PrngGetsRightMasksWhenAllClientsSurvive) {
435   // First, set up necessary data for the SecAggServerPrngRunningState
436   auto input_vector_specs = std::vector<InputVectorSpecification>();
437   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
438   auto sender = std::make_unique<MockSendToClientsInterface>();
439   FakePrng prng;
440   ShamirSecretSharing sharer;
441   auto self_shamir_share_table = std::make_unique<
442       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
443   for (int i = 0; i < 4; ++i) {
444     self_shamir_share_table->insert(std::make_pair(
445         i, sharer.Share(3, 4,
446                         MakeAesKey(absl::StrCat(
447                             "test 32 byte AES key for user #", i)))));
448   }
449 
450   // Generate the expected (negative) sum of masking vectors using MapofMasks.
451   std::vector<AesKey> prng_keys_to_add;
452   std::vector<AesKey> prng_keys_to_subtract;
453   for (int i = 0; i < 4; ++i) {
454     prng_keys_to_subtract.push_back(
455         MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
456   }
457   auto session_id = MakeTestSessionId();
458   auto expected_map_of_masks =
459       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
460                  *session_id, AesCtrPrngFactory());
461 
462   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
463   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
464   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
465   impl->set_masked_input(std::move(zero_map));
466   impl->set_pairwise_shamir_share_table(
467       std::make_unique<
468           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
469   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
470 
471   SecAggServerPrngRunningState state(
472       std::move(impl),
473       0,   // number_of_clients_failed_after_sending_masked_input
474       0,   // number_of_clients_failed_before_sending_masked_input
475       0);  // number_of_clients_terminated_without_unmasking
476 
477   MockPrngDone prng_done;
478   EXPECT_CALL(prng_done, Callback());
479 
480   state.EnterState();
481   state.SetAsyncCallback([&]() { prng_done.Callback(); });
482 
483   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
484 
485   auto next_state = state.ProceedToNextRound();
486   ASSERT_THAT(next_state.ok(), Eq(true));
487   ASSERT_THAT(next_state.value()->State(),
488               Eq(SecAggServerStateKind::COMPLETED));
489   auto result = next_state.value()->Result();
490   ASSERT_THAT(result.ok(), Eq(true));
491   EXPECT_THAT(*result.value(),
492               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
493 }
494 
TEST(SecaggServerPrngRunningStateTest,PrngGetsRightMasksWithOneDeadClientAfterSendingInput)495 TEST(SecaggServerPrngRunningStateTest,
496      PrngGetsRightMasksWithOneDeadClientAfterSendingInput) {
497   // In this test, client 1 died after sending its masked input. Its input will
498   // still be included.
499   //
500   // First, set up necessary data for the SecAggServerPrngRunningState.
501   auto input_vector_specs = std::vector<InputVectorSpecification>();
502   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
503   auto sender = std::make_unique<MockSendToClientsInterface>();
504   FakePrng prng;
505   ShamirSecretSharing sharer;
506   auto pairwise_shamir_share_table = std::make_unique<
507       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
508   auto self_shamir_share_table = std::make_unique<
509       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
510 
511   auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
512   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
513   impl->set_client_status(
514       1, ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
515 
516   aborted_client_ids->insert(1);
517 
518   for (int i = 0; i < 4; ++i) {
519     self_shamir_share_table->insert(std::make_pair(
520         i, sharer.Share(3, 4,
521                         MakeAesKey(absl::StrCat(
522                             "test 32 byte AES key for user #", i)))));
523     // Blank out the share in position 1 because it would not have been sent.
524     (*self_shamir_share_table)[i][1] = {""};
525   }
526 
527   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
528   zero_map->insert(std::make_pair(
529       "foobar", SecAggUnpackedVector(std::vector<uint64_t>{0, 0, 0, 0}, 32)));
530   impl->set_masked_input(std::move(zero_map));
531   impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
532   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
533 
534   // Generate the expected (negative) sum of masking vectors using MapofMasks.
535   std::vector<AesKey> prng_keys_to_add;
536   std::vector<AesKey> prng_keys_to_subtract;
537   for (int i = 0; i < 4; ++i) {
538     prng_keys_to_subtract.push_back(
539         MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
540   }
541   auto session_id = MakeTestSessionId();
542   auto expected_map_of_masks =
543       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
544                  *session_id, AesCtrPrngFactory());
545 
546   SecAggServerPrngRunningState state(
547       std::move(impl),
548       1,   // number_of_clients_failed_after_sending_masked_input
549       0,   // number_of_clients_failed_before_sending_masked_input
550       1);  // number_of_clients_terminated_without_unmasking
551 
552   MockPrngDone prng_done;
553   EXPECT_CALL(prng_done, Callback());
554 
555   state.EnterState();
556   state.SetAsyncCallback([&]() { prng_done.Callback(); });
557 
558   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
559 
560   auto next_state = state.ProceedToNextRound();
561   ASSERT_THAT(next_state.ok(), Eq(true));
562   ASSERT_THAT(next_state.value()->State(),
563               Eq(SecAggServerStateKind::COMPLETED));
564   auto result = next_state.value()->Result();
565   ASSERT_THAT(result.ok(), Eq(true));
566   EXPECT_THAT(*result.value(),
567               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
568 }
569 
TEST(SecaggServerPrngRunningStateTest,PrngGetsRightMasksWithOneDeadClientBeforeSendingInput)570 TEST(SecaggServerPrngRunningStateTest,
571      PrngGetsRightMasksWithOneDeadClientBeforeSendingInput) {
572   // In this test, client 1 died before sending its masked input but after other
573   // clients computed theirs, so its pairwise key will need to be canceled out.
574   //
575   // First, set up necessary data for the SecAggServerPrngRunningState.
576   auto input_vector_specs = std::vector<InputVectorSpecification>();
577   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
578   auto sender = std::make_unique<MockSendToClientsInterface>();
579   FakePrng prng;
580   ShamirSecretSharing sharer;
581   auto pairwise_shamir_share_table = std::make_unique<
582       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
583   auto self_shamir_share_table = std::make_unique<
584       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
585 
586   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
587   impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
588 
589   auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
590   aborted_client_ids->insert(1);
591 
592   EcdhPregeneratedTestKeys ecdh_keys;
593   for (int i = 0; i < 4; ++i) {
594     if (i == 1) {
595       // Client 1 died in the previous step, so the other clients will have sent
596       // shares of its pairwise key instead.
597       pairwise_shamir_share_table->insert(
598           std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
599       // Blank out the share in position 1 because it would not have been sent.
600       (*pairwise_shamir_share_table)[i][1] = {""};
601     } else {
602       self_shamir_share_table->insert(std::make_pair(
603           i, sharer.Share(3, 4,
604                           MakeAesKey(absl::StrCat(
605                               "test 32 byte AES key for user #", i)))));
606       // Blank out the share in position 1 because it would not have been sent.
607       (*self_shamir_share_table)[i][1] = {""};
608     }
609   }
610 
611   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
612   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
613   impl->set_masked_input(std::move(zero_map));
614   impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
615   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
616 
617   // Generate the expected (negative) sum of masking vectors using MapofMasks.
618   // We should subtract the self masks of clients 0, 2, and 3. We should
619   // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
620   // that 0 subtracted for 1.
621   auto aborted_client_key_agreement =
622       EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
623   std::vector<AesKey> prng_keys_to_add;
624   std::vector<AesKey> prng_keys_to_subtract;
625   for (int i = 0; i < 4; ++i) {
626     if (i == 1) {
627       continue;
628     }
629     prng_keys_to_subtract.push_back(
630         MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
631     AesKey pairwise_key = aborted_client_key_agreement.value()
632                               ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
633                               .value();
634     if (i == 0) {
635       prng_keys_to_add.push_back(pairwise_key);
636     } else {
637       prng_keys_to_subtract.push_back(pairwise_key);
638     }
639   }
640   auto session_id = MakeTestSessionId();
641   auto expected_map_of_masks =
642       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
643                  *session_id, AesCtrPrngFactory());
644 
645   SecAggServerPrngRunningState state(
646       std::move(impl),
647       0,   // number_of_clients_failed_after_sending_masked_input
648       1,   // number_of_clients_failed_before_sending_masked_input
649       0);  // number_of_clients_terminated_without_unmasking
650 
651   MockPrngDone prng_done;
652   EXPECT_CALL(prng_done, Callback());
653 
654   state.EnterState();
655   state.SetAsyncCallback([&]() { prng_done.Callback(); });
656 
657   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
658 
659   auto next_state = state.ProceedToNextRound();
660   ASSERT_THAT(next_state.ok(), Eq(true));
661   ASSERT_THAT(next_state.value()->State(),
662               Eq(SecAggServerStateKind::COMPLETED));
663   auto result = next_state.value()->Result();
664   ASSERT_THAT(result.ok(), Eq(true));
665   EXPECT_THAT(*result.value(),
666               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
667 }
668 
TEST(SecaggServerPrngRunningStateTest,PrngGetsRightMasksAndCallsCallbackIfSpecified)669 TEST(SecaggServerPrngRunningStateTest,
670      PrngGetsRightMasksAndCallsCallbackIfSpecified) {
671   // In this test, there is now a callback that should be called when the PRNG
672   // is done running.
673   //
674   // First, set up necessary data for the SecAggServerPrngRunningState.
675   auto input_vector_specs = std::vector<InputVectorSpecification>();
676   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
677   auto sender = std::make_unique<MockSendToClientsInterface>();
678   FakePrng prng;
679   ShamirSecretSharing sharer;
680   auto pairwise_shamir_share_table = std::make_unique<
681       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
682   auto self_shamir_share_table = std::make_unique<
683       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
684 
685   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
686   impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
687 
688   auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
689   aborted_client_ids->insert(1);
690 
691   EcdhPregeneratedTestKeys ecdh_keys;
692   for (int i = 0; i < 4; ++i) {
693     if (i == 1) {
694       // Client 1 died in the previous step, so the other clients will have sent
695       // shares of its pairwise key instead.
696       pairwise_shamir_share_table->insert(
697           std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
698       // Blank out the share in position 1 because it would not have been sent.
699       (*pairwise_shamir_share_table)[i][1] = {""};
700     } else {
701       self_shamir_share_table->insert(std::make_pair(
702           i, sharer.Share(3, 4,
703                           MakeAesKey(absl::StrCat(
704                               "test 32 byte AES key for user #", i)))));
705       // Blank out the share in position 1 because it would not have been sent.
706       (*self_shamir_share_table)[i][1] = {""};
707     }
708   }
709 
710   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
711   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
712   impl->set_masked_input(std::move(zero_map));
713   impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
714   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
715 
716   // Generate the expected (negative) sum of masking vectors using MapofMasks.
717   // We should subtract the self masks of clients 0, 2, and 3. We should
718   // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
719   // that 0 subtracted for 1.
720   auto aborted_client_key_agreement =
721       EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
722   std::vector<AesKey> prng_keys_to_add;
723   std::vector<AesKey> prng_keys_to_subtract;
724   for (int i = 0; i < 4; ++i) {
725     if (i == 1) {
726       continue;
727     }
728     prng_keys_to_subtract.push_back(
729         MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
730     AesKey pairwise_key = aborted_client_key_agreement.value()
731                               ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
732                               .value();
733     if (i == 0) {
734       prng_keys_to_add.push_back(pairwise_key);
735     } else {
736       prng_keys_to_subtract.push_back(pairwise_key);
737     }
738   }
739   auto session_id = MakeTestSessionId();
740   auto expected_map_of_masks =
741       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
742                  *session_id, AesCtrPrngFactory());
743 
744   SecAggServerPrngRunningState state(
745       std::move(impl),
746       0,   // number_of_clients_failed_after_sending_masked_input
747       1,   // number_of_clients_failed_before_sending_masked_input
748       0);  // number_of_clients_terminated_without_unmasking
749 
750   MockPrngDone prng_done;
751   EXPECT_CALL(prng_done, Callback());
752 
753   state.EnterState();
754   state.SetAsyncCallback([&]() { prng_done.Callback(); });
755 
756   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
757 
758   auto next_state = state.ProceedToNextRound();
759   ASSERT_THAT(next_state.ok(), Eq(true));
760   ASSERT_THAT(next_state.value()->State(),
761               Eq(SecAggServerStateKind::COMPLETED));
762   auto result = next_state.value()->Result();
763   ASSERT_THAT(result.ok(), Eq(true));
764   EXPECT_THAT(*result.value(),
765               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
766 }
767 
TEST(SecaggServerPrngRunningStateTest,SetAsyncCallbackCanBeCalledTwice)768 TEST(SecaggServerPrngRunningStateTest, SetAsyncCallbackCanBeCalledTwice) {
769   // StartPrng should have the property that it can be called after it has
770   // already run successfully without any problems. It should just return OK
771   // again.
772   //
773   // First, set up necessary data for the SecAggServerPrngRunningState.
774   auto input_vector_specs = std::vector<InputVectorSpecification>();
775   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
776   auto sender = std::make_unique<MockSendToClientsInterface>();
777   FakePrng prng;
778   ShamirSecretSharing sharer;
779   auto pairwise_shamir_share_table = std::make_unique<
780       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
781   auto self_shamir_share_table = std::make_unique<
782       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
783 
784   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
785   impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
786 
787   auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
788   aborted_client_ids->insert(1);
789 
790   EcdhPregeneratedTestKeys ecdh_keys;
791   for (int i = 0; i < 4; ++i) {
792     if (i == 1) {
793       // Client 1 died in the previous step, so the other clients will have sent
794       // shares of its pairwise key instead.
795       pairwise_shamir_share_table->insert(
796           std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
797       // Blank out the share in position 1 because it would not have been sent.
798       (*pairwise_shamir_share_table)[i][1] = {""};
799     } else {
800       self_shamir_share_table->insert(std::make_pair(
801           i, sharer.Share(3, 4,
802                           MakeAesKey(absl::StrCat(
803                               "test 32 byte AES key for user #", i)))));
804       // Blank out the share in position 1 because it would not have been sent.
805       (*self_shamir_share_table)[i][1] = {""};
806     }
807   }
808 
809   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
810   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
811   impl->set_masked_input(std::move(zero_map));
812   impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
813   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
814 
815   // Generate the expected (negative) sum of masking vectors using MapofMasks.
816   // We should subtract the self masks of clients 0, 2, and 3. We should
817   // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
818   // that 0 subtracted for 1.
819   auto aborted_client_key_agreement =
820       EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
821   std::vector<AesKey> prng_keys_to_add;
822   std::vector<AesKey> prng_keys_to_subtract;
823   for (int i = 0; i < 4; ++i) {
824     if (i == 1) {
825       continue;
826     }
827     prng_keys_to_subtract.push_back(
828         MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
829     AesKey pairwise_key = aborted_client_key_agreement.value()
830                               ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
831                               .value();
832     if (i == 0) {
833       prng_keys_to_add.push_back(pairwise_key);
834     } else {
835       prng_keys_to_subtract.push_back(pairwise_key);
836     }
837   }
838   auto session_id = MakeTestSessionId();
839   auto expected_map_of_masks =
840       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
841                  *session_id, AesCtrPrngFactory());
842 
843   SecAggServerPrngRunningState state(
844       std::move(impl),
845       0,   // number_of_clients_failed_after_sending_masked_input
846       1,   // number_of_clients_failed_before_sending_masked_input
847       0);  // number_of_clients_terminated_without_unmasking
848 
849   MockPrngDone prng_done;
850   EXPECT_CALL(prng_done, Callback());
851 
852   state.EnterState();
853   state.SetAsyncCallback([&]() { prng_done.Callback(); });
854 
855   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
856 
857   // Make sure we can call SetAsyncCallback again.
858   MockPrngDone prng_done_2;
859   EXPECT_CALL(prng_done_2, Callback());
860   state.SetAsyncCallback([&]() { prng_done_2.Callback(); });
861 
862   auto next_state = state.ProceedToNextRound();
863   ASSERT_THAT(next_state.ok(), Eq(true));
864   ASSERT_THAT(next_state.value()->State(),
865               Eq(SecAggServerStateKind::COMPLETED));
866   auto result = next_state.value()->Result();
867   ASSERT_THAT(result.ok(), Eq(true));
868   EXPECT_THAT(*result.value(),
869               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
870 }
871 
TEST(SecaggServerPrngRunningStateTest,PrngGetsRightMasksWhenClientsUse16BSelfKeys)872 TEST(SecaggServerPrngRunningStateTest,
873      PrngGetsRightMasksWhenClientsUse16BSelfKeys) {
874   // TODO(team): This test is only for ensuring Java compatibility.
875   // First, set up necessary data for the SecAggServerPrngRunningState
876   auto input_vector_specs = std::vector<InputVectorSpecification>();
877   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
878   auto sender = std::make_unique<MockSendToClientsInterface>();
879   FakePrng prng;
880   ShamirSecretSharing sharer;
881   auto self_shamir_share_table = std::make_unique<
882       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
883   for (int i = 0; i < 4; ++i) {
884     self_shamir_share_table->insert(std::make_pair(
885         i, sharer.Share(3, 4,
886                         AesKey(reinterpret_cast<const uint8_t*>(
887                                    absl::StrCat("16B key of user", i).c_str()),
888                                16))));
889   }
890 
891   // Generate the expected (negative) sum of masking vectors using MapofMasks.
892   std::vector<AesKey> prng_keys_to_add;
893   std::vector<AesKey> prng_keys_to_subtract;
894   for (int i = 0; i < 4; ++i) {
895     prng_keys_to_subtract.push_back(
896         AesKey(reinterpret_cast<const uint8_t*>(
897                    absl::StrCat("16B key of user", i).c_str()),
898                16));
899   }
900   auto session_id = MakeTestSessionId();
901   auto expected_map_of_masks =
902       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
903                  *session_id, AesCtrPrngFactory());
904 
905   auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
906   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
907   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
908   impl->set_masked_input(std::move(zero_map));
909   impl->set_pairwise_shamir_share_table(
910       std::make_unique<
911           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
912   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
913 
914   SecAggServerPrngRunningState state(
915       std::move(impl),
916       0,   // number_of_clients_failed_after_sending_masked_input
917       0,   // number_of_clients_failed_before_sending_masked_input
918       0);  // number_of_clients_terminated_without_unmasking
919 
920   MockPrngDone prng_done;
921   EXPECT_CALL(prng_done, Callback());
922 
923   state.EnterState();
924   state.SetAsyncCallback([&]() { prng_done.Callback(); });
925 
926   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
927 
928   auto next_state = state.ProceedToNextRound();
929   ASSERT_THAT(next_state.ok(), Eq(true));
930   ASSERT_THAT(next_state.value()->State(),
931               Eq(SecAggServerStateKind::COMPLETED));
932   auto result = next_state.value()->Result();
933   ASSERT_THAT(result.ok(), Eq(true));
934   EXPECT_THAT(*result.value(),
935               testing::MatchesSecAggVectorMap(*expected_map_of_masks));
936 }
937 
TEST(SecaggServerPrngRunningStateTest,TimingMetricsAreRecorded)938 TEST(SecaggServerPrngRunningStateTest, TimingMetricsAreRecorded) {
939   // First, set up necessary data for the SecAggServerPrngRunningState
940   TestTracingRecorder tracing_recorder;
941   auto input_vector_specs = std::vector<InputVectorSpecification>();
942   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
943   MockSecAggServerMetricsListener* metrics =
944       new MockSecAggServerMetricsListener();
945   auto sender = std::make_unique<MockSendToClientsInterface>();
946   FakePrng prng;
947   ShamirSecretSharing sharer;
948   auto self_shamir_share_table = std::make_unique<
949       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
950   for (int i = 0; i < 4; ++i) {
951     self_shamir_share_table->insert(std::make_pair(
952         i, sharer.Share(3, 4,
953                         MakeAesKey(absl::StrCat(
954                             "test 32 byte AES key for user #", i)))));
955   }
956 
957   auto impl =
958       CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
959   auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
960   zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
961   impl->set_masked_input(std::move(zero_map));
962   impl->set_pairwise_shamir_share_table(
963       std::make_unique<
964           absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
965   impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
966 
967   SecAggServerPrngRunningState state(
968       std::move(impl),
969       0,   // number_of_clients_failed_after_sending_masked_input
970       0,   // number_of_clients_failed_before_sending_masked_input
971       0);  // number_of_clients_terminated_without_unmasking
972 
973   MockPrngDone prng_done;
974   EXPECT_CALL(prng_done, Callback());
975 
976   EXPECT_CALL(*metrics, PrngExpansionTimes(Ge(0)));
977   EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::PRNG_RUNNING),
978                                    Eq(true), Ge(0)));
979   EXPECT_CALL(*metrics, ShamirReconstructionTimes(Ge(0)));
980 
981   state.EnterState();
982   state.SetAsyncCallback([&]() { prng_done.Callback(); });
983   EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
984 
985   auto next_state = state.ProceedToNextRound();
986   ASSERT_THAT(next_state.ok(), Eq(true));
987   ASSERT_THAT(next_state.value()->State(),
988               Eq(SecAggServerStateKind::COMPLETED));
989   EXPECT_THAT(tracing_recorder.FindAllEvents<ShamirReconstruction>(),
990               ElementsAre(IsEvent<ShamirReconstruction>(Ge(0))));
991   EXPECT_THAT(tracing_recorder.FindAllEvents<PrngExpansion>(),
992               ElementsAre(IsEvent<PrngExpansion>(Ge(0))));
993 }
994 
995 }  // namespace
996 }  // namespace secagg
997 }  // namespace fcp
998