1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_prng_running_state.h"
18
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/container/node_hash_map.h"
29 #include "absl/container/node_hash_set.h"
30 #include "absl/strings/str_cat.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/base/scheduler.h"
33 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
34 #include "fcp/secagg/server/secagg_scheduler.h"
35 #include "fcp/secagg/server/secagg_server_enums.pb.h"
36 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
37 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
38 #include "fcp/secagg/shared/aes_key.h"
39 #include "fcp/secagg/shared/ecdh_key_agreement.h"
40 #include "fcp/secagg/shared/ecdh_keys.h"
41 #include "fcp/secagg/shared/input_vector_specification.h"
42 #include "fcp/secagg/shared/map_of_masks.h"
43 #include "fcp/secagg/shared/secagg_messages.pb.h"
44 #include "fcp/secagg/shared/secagg_vector.h"
45 #include "fcp/secagg/shared/shamir_secret_sharing.h"
46 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
47 #include "fcp/secagg/testing/fake_prng.h"
48 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
49 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
50 #include "fcp/secagg/testing/server/test_async_runner.h"
51 #include "fcp/secagg/testing/test_matchers.h"
52 #include "fcp/testing/testing.h"
53 #include "fcp/tracing/test_tracing_recorder.h"
54
55 namespace fcp {
56 namespace secagg {
57 namespace {
58
59 using ::testing::_;
60 using ::testing::Eq;
61 using ::testing::Ge;
62 using ::testing::NiceMock;
63
64 // For testing purposes, make an AesKey out of a string.
MakeAesKey(const std::string & key)65 AesKey MakeAesKey(const std::string& key) {
66 EXPECT_THAT(key.size(), Eq(AesKey::kSize));
67 return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
68 }
69
70 class MockScheduler : public Scheduler {
71 public:
72 MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
73 MOCK_METHOD(void, WaitUntilIdle, ());
74 };
75
// gMock action for Schedule(): runs the submitted task immediately on the
// calling thread.
constexpr auto call_fn = [](const std::function<void()>& task) {
  std::invoke(task);
};
77
78 // Default test session_id.
MakeTestSessionId()79 std::unique_ptr<SessionId> MakeTestSessionId() {
80 SessionId session_id = {"session id number, 32 bytes long"};
81 return std::make_unique<SessionId>(session_id);
82 }
83
CreateSecAggServerProtocolImpl(std::vector<InputVectorSpecification> input_vector_specs,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)84 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
85 std::vector<InputVectorSpecification> input_vector_specs,
86 MockSendToClientsInterface* sender,
87 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
88 SecretSharingGraphFactory factory;
89 auto parallel_scheduler = std::make_unique<NiceMock<MockScheduler>>();
90 auto sequential_scheduler = std::make_unique<NiceMock<MockScheduler>>();
91 EXPECT_CALL(*parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
92 EXPECT_CALL(*sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);
93 auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
94 factory.CreateCompleteGraph(4, 3), // total number of clients is 4
95 3, // minimum_number_of_clients_to_proceed
96 input_vector_specs,
97 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
98 std::make_unique<AesCtrPrngFactory>(), sender,
99 std::make_unique<TestAsyncRunner>(std::move(parallel_scheduler),
100 std::move(sequential_scheduler)),
101 std::vector<ClientStatus>(4, ClientStatus::UNMASKING_RESPONSE_RECEIVED),
102 ServerVariant::NATIVE_V1);
103 impl->set_session_id(MakeTestSessionId());
104 EcdhPregeneratedTestKeys ecdh_keys;
105 for (int i = 0; i < 4; ++i) {
106 impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
107 }
108 impl->set_masked_input(std::make_unique<SecAggUnpackedVectorMap>());
109 return impl;
110 }
111
112 // Mock class containing a callback that would be called when the PRNG is done.
113 class MockPrngDone {
114 public:
115 MOCK_METHOD(void, Callback, ());
116 };
117
// A newly constructed PRNG-running state does not report itself as aborted.
TEST(SecaggServerPrngRunningStateTest, IsAbortedReturnsFalse) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.IsAborted(), Eq(false));
}
147
// A newly constructed PRNG-running state is not yet completed successfully.
TEST(SecaggServerPrngRunningStateTest, IsCompletedSuccessfullyReturnsFalse) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.IsCompletedSuccessfully(), Eq(false));
}
177
// ErrorMessage() returns a non-OK result for a state that has not aborted.
TEST(SecaggServerPrngRunningStateTest, ErrorMessageRaisesError) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.ErrorMessage().ok(), Eq(false));
}
207
// No client messages are counted for this round in a fresh state.
TEST(SecaggServerPrngRunningStateTest,
     NumberOfMessagesReceivedInThisRoundReturnsZero) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
}
238
// No clients are reported ready for the next round in a fresh state.
TEST(SecaggServerPrngRunningStateTest,
     NumberOfClientsReadyForNextRoundReturnsZero) {
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
}
269
// A non-abort message received in this state makes the server send that client
// an abort message and mark it aborted. The client is not reported through
// ClientsDropped() and is not counted as failed after sending masked input.
TEST(SecaggServerPrngRunningStateTest,
     HandleNonAbortMessageAbortsClientDoesNotRecordMetrics) {
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // Ownership is taken over by CreateSecAggServerProtocolImpl below.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "Non-abort message sent during PrngUnmasking step.");

  // An empty wrapper message: its content case is MESSAGE_CONTENT_NOT_SET.
  ClientToServerWrapperMessage client_message;
  EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(abort_message)));
  EXPECT_CALL(*metrics, MessageReceivedSizes(
                            Eq(ClientToServerWrapperMessage::
                                   MessageContentCase::MESSAGE_CONTENT_NOT_SET),
                            Eq(false), Eq(client_message.ByteSizeLong())));
  EXPECT_CALL(*metrics,
              IndividualMessageSizes(
                  Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
                  Eq(abort_message.ByteSizeLong())));
  EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);

  EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
  EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
  EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
              ElementsAre(IsEvent<IndividualMessageSent>(
                  Eq(0), Eq(ServerToClientMessageType_Abort),
                  Eq(abort_message.ByteSizeLong()))));
  EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
              ElementsAre(IsEvent<ClientMessageReceived>(
                  Eq(ClientToServerMessageType_MessageContentNotSet),
                  Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
}
331
// An abort message from a client marks it aborted without sending it anything.
// The client is not reported through ClientsDropped() and is not counted as
// failed after sending masked input.
TEST(SecaggServerPrngRunningStateTest,
     HandleAbortMessageAbortsClientDoesNotRecordMetrics) {
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // Ownership is taken over by CreateSecAggServerProtocolImpl below.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  ClientToServerWrapperMessage client_message;
  client_message.mutable_abort();
  EXPECT_CALL(*metrics,
              MessageReceivedSizes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
                  Eq(false), Eq(client_message.ByteSizeLong())));
  EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
  EXPECT_CALL(*sender, Send(Eq(0), _)).Times(0);

  EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
  EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
  EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
              ElementsAre(IsEvent<ClientMessageReceived>(
                  Eq(ClientToServerMessageType_Abort),
                  Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
}
381
// Abort() broadcasts an abort message to all clients, records the outcome
// metric, and returns an ABORTED state carrying the given reason.
TEST(SecaggServerPrngRunningStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // Ownership is taken over by CreateSecAggServerProtocolImpl below.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  // Shamir shares of each client's test AES key.
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->try_emplace(
        i, sharer.Share(
               3, 4,
               MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::UNHANDLED_ERROR)));
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
  auto next_state =
      state.Abort("test abort reason", SecAggServerOutcome::UNHANDLED_ERROR);

  ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(next_state->ErrorMessage().ok(), Eq(true));
  EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(abort_message.ByteSizeLong()))));
}
432
// With all four clients alive and an all-zero masked input, the completed
// result equals the negated sum of every client's self mask.
TEST(SecaggServerPrngRunningStateTest,
     PrngGetsRightMasksWhenAllClientsSurvive) {
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->insert(std::make_pair(
        i, sharer.Share(3, 4,
                        MakeAesKey(absl::StrCat(
                            "test 32 byte AES key for user #", i)))));
  }

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    prng_keys_to_subtract.push_back(
        MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  // All-zero masked input, so the result is exactly the (negated) masks.
  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
494
TEST(SecaggServerPrngRunningStateTest,
     PrngGetsRightMasksWithOneDeadClientAfterSendingInput) {
  // In this test, client 1 died after sending its masked input. Its input will
  // still be included.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto pairwise_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  // Client 1's death is conveyed to the state solely through its status.
  impl->set_client_status(
      1, ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);

  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->insert(std::make_pair(
        i, sharer.Share(3, 4,
                        MakeAesKey(absl::StrCat(
                            "test 32 byte AES key for user #", i)))));
    // Blank out the share in position 1 because it would not have been sent.
    (*self_shamir_share_table)[i][1] = {""};
  }

  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->insert(std::make_pair(
      "foobar", SecAggUnpackedVector(std::vector<uint64_t>{0, 0, 0, 0}, 32)));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    prng_keys_to_subtract.push_back(
        MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  SecAggServerPrngRunningState state(
      std::move(impl),
      1,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      1);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
569
TEST(SecaggServerPrngRunningStateTest,
     PrngGetsRightMasksWithOneDeadClientBeforeSendingInput) {
  // In this test, client 1 died before sending its masked input but after other
  // clients computed theirs, so its pairwise key will need to be canceled out.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto pairwise_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  // Client 1's death is conveyed to the state solely through its status.
  impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);

  EcdhPregeneratedTestKeys ecdh_keys;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      // Client 1 died in the previous step, so the other clients will have sent
      // shares of its pairwise key instead.
      pairwise_shamir_share_table->insert(
          std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
      // Blank out the share in position 1 because it would not have been sent.
      (*pairwise_shamir_share_table)[i][1] = {""};
    } else {
      self_shamir_share_table->insert(std::make_pair(
          i, sharer.Share(3, 4,
                          MakeAesKey(absl::StrCat(
                              "test 32 byte AES key for user #", i)))));
      // Blank out the share in position 1 because it would not have been sent.
      (*self_shamir_share_table)[i][1] = {""};
    }
  }

  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  // We should subtract the self masks of clients 0, 2, and 3. We should
  // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
  // that 0 subtracted for 1.
  auto aborted_client_key_agreement =
      EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      continue;
    }
    prng_keys_to_subtract.push_back(
        MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
    AesKey pairwise_key = aborted_client_key_agreement.value()
                              ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
                              .value();
    if (i == 0) {
      prng_keys_to_add.push_back(pairwise_key);
    } else {
      prng_keys_to_subtract.push_back(pairwise_key);
    }
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      1,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
668
TEST(SecaggServerPrngRunningStateTest,
     PrngGetsRightMasksAndCallsCallbackIfSpecified) {
  // In this test, there is now a callback that should be called when the PRNG
  // is done running.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto pairwise_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  // Client 1's death is conveyed to the state solely through its status.
  impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);

  EcdhPregeneratedTestKeys ecdh_keys;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      // Client 1 died in the previous step, so the other clients will have sent
      // shares of its pairwise key instead.
      pairwise_shamir_share_table->insert(
          std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
      // Blank out the share in position 1 because it would not have been sent.
      (*pairwise_shamir_share_table)[i][1] = {""};
    } else {
      self_shamir_share_table->insert(std::make_pair(
          i, sharer.Share(3, 4,
                          MakeAesKey(absl::StrCat(
                              "test 32 byte AES key for user #", i)))));
      // Blank out the share in position 1 because it would not have been sent.
      (*self_shamir_share_table)[i][1] = {""};
    }
  }

  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  // We should subtract the self masks of clients 0, 2, and 3. We should
  // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
  // that 0 subtracted for 1.
  auto aborted_client_key_agreement =
      EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      continue;
    }
    prng_keys_to_subtract.push_back(
        MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
    AesKey pairwise_key = aborted_client_key_agreement.value()
                              ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
                              .value();
    if (i == 0) {
      prng_keys_to_add.push_back(pairwise_key);
    } else {
      prng_keys_to_subtract.push_back(pairwise_key);
    }
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      1,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
767
TEST(SecaggServerPrngRunningStateTest, SetAsyncCallbackCanBeCalledTwice) {
  // SetAsyncCallback should tolerate being called a second time, here after
  // the PRNG computation has already finished (ReadyForNextRound() is true).
  // Each callback registered below is expected to be invoked exactly once.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto pairwise_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);

  EcdhPregeneratedTestKeys ecdh_keys;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      // Client 1 died in the previous step, so the other clients will have sent
      // shares of its pairwise key instead.
      pairwise_shamir_share_table->insert(
          std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
      // Blank out the share in position 1 because it would not have been sent.
      (*pairwise_shamir_share_table)[i][1] = {""};
    } else {
      self_shamir_share_table->insert(std::make_pair(
          i, sharer.Share(3, 4,
                          MakeAesKey(absl::StrCat(
                              "test 32 byte AES key for user #", i)))));
      // Blank out the share in position 1 because it would not have been sent.
      (*self_shamir_share_table)[i][1] = {""};
    }
  }

  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  // We should subtract the self masks of clients 0, 2, and 3. We should
  // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
  // that 0 subtracted for 1.
  auto aborted_client_key_agreement =
      EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    if (i == 1) {
      continue;
    }
    prng_keys_to_subtract.push_back(
        MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
    AesKey pairwise_key = aborted_client_key_agreement.value()
                              ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
                              .value();
    if (i == 0) {
      prng_keys_to_add.push_back(pairwise_key);
    } else {
      prng_keys_to_subtract.push_back(pairwise_key);
    }
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      1,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  // Make sure we can call SetAsyncCallback again.
  MockPrngDone prng_done_2;
  EXPECT_CALL(prng_done_2, Callback());
  state.SetAsyncCallback([&]() { prng_done_2.Callback(); });

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
871
TEST(SecaggServerPrngRunningStateTest,
     PrngGetsRightMasksWhenClientsUse16BSelfKeys) {
  // TODO(team): This test is only for ensuring Java compatibility.
  // Verifies that the PRNG state produces the expected masks when every client
  // used a 16-byte (rather than 32-byte) self key.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;

  // Builds client i's 16-byte self key. Used both to build the share table and
  // the expected masks so the two can never drift apart. The prefix
  // "16B key of user" is 15 characters, so appending a single-digit id yields
  // exactly the 16 bytes consumed below.
  auto make_16b_self_key = [](int i) {
    const std::string key_material = absl::StrCat("16B key of user", i);
    return AesKey(reinterpret_cast<const uint8_t*>(key_material.data()), 16);
  };

  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->insert(
        std::make_pair(i, sharer.Share(3, 4, make_16b_self_key(i))));
  }

  // Generate the expected (negative) sum of masking vectors using MapOfMasks.
  // No client is marked as failed, so only the self masks of all four clients
  // are subtracted and nothing is added.
  std::vector<AesKey> prng_keys_to_add;
  std::vector<AesKey> prng_keys_to_subtract;
  for (int i = 0; i < 4; ++i) {
    prng_keys_to_subtract.push_back(make_16b_self_key(i));
  }
  auto session_id = MakeTestSessionId();
  auto expected_map_of_masks =
      MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
                 *session_id, AesCtrPrngFactory());

  auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });

  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  auto result = next_state.value()->Result();
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(*result.value(),
              testing::MatchesSecAggVectorMap(*expected_map_of_masks));
}
937
TEST(SecaggServerPrngRunningStateTest, TimingMetricsAreRecorded) {
  // Verifies that running the PRNG state reports PRNG expansion, round, and
  // Shamir reconstruction timings to the metrics listener, and records one
  // ShamirReconstruction and one PrngExpansion tracing event.
  //
  // First, set up necessary data for the SecAggServerPrngRunningState.
  TestTracingRecorder tracing_recorder;
  auto input_vector_specs = std::vector<InputVectorSpecification>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  // NOTE(review): this raw `new` has no matching delete in this test;
  // presumably CreateSecAggServerProtocolImpl takes ownership of the listener.
  // Confirm before changing ownership here.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  ShamirSecretSharing sharer;
  auto self_shamir_share_table = std::make_unique<
      absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
  for (int i = 0; i < 4; ++i) {
    self_shamir_share_table->insert(std::make_pair(
        i, sharer.Share(3, 4,
                        MakeAesKey(absl::StrCat(
                            "test 32 byte AES key for user #", i)))));
  }

  auto impl =
      CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
  auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
  zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
  impl->set_masked_input(std::move(zero_map));
  impl->set_pairwise_shamir_share_table(
      std::make_unique<
          absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
  impl->set_self_shamir_share_table(std::move(self_shamir_share_table));

  SecAggServerPrngRunningState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  MockPrngDone prng_done;
  EXPECT_CALL(prng_done, Callback());

  // Expectations are set before EnterState() so that metrics reported while
  // the PRNG runs are matched.
  EXPECT_CALL(*metrics, PrngExpansionTimes(Ge(0)));
  EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::PRNG_RUNNING),
                                   Eq(true), Ge(0)));
  EXPECT_CALL(*metrics, ShamirReconstructionTimes(Ge(0)));

  state.EnterState();
  state.SetAsyncCallback([&]() { prng_done.Callback(); });
  EXPECT_THAT(state.ReadyForNextRound(), Eq(true));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state.ok(), Eq(true));
  ASSERT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::COMPLETED));
  EXPECT_THAT(tracing_recorder.FindAllEvents<ShamirReconstruction>(),
              ElementsAre(IsEvent<ShamirReconstruction>(Ge(0))));
  EXPECT_THAT(tracing_recorder.FindAllEvents<PrngExpansion>(),
              ElementsAre(IsEvent<PrngExpansion>(Ge(0))));
}
994
995 } // namespace
996 } // namespace secagg
997 } // namespace fcp
998