1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/strings/str_cat.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
27 #include "fcp/secagg/server/secagg_server_state.h"
28 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
29 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
30 #include "fcp/secagg/shared/ecdh_keys.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
34 #include "fcp/secagg/testing/fake_prng.h"
35 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
36 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
37 #include "fcp/testing/testing.h"
38 #include "fcp/tracing/test_tracing_recorder.h"
39 
40 namespace fcp {
41 namespace secagg {
42 namespace {
43 
44 using ::testing::_;
45 using ::testing::Eq;
46 using ::testing::Ge;
47 using ::testing::IsFalse;
48 using ::testing::IsTrue;
49 using ::testing::Ne;
50 
51 // Default test session_id.
52 SessionId session_id = {"session id number, 32 bytes long"};
53 
CreateSecAggServerProtocolImpl(int minimum_number_of_clients_to_proceed,int total_number_of_clients,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)54 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
55     int minimum_number_of_clients_to_proceed, int total_number_of_clients,
56     MockSendToClientsInterface* sender,
57     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
58   auto input_vector_specs = std::vector<InputVectorSpecification>();
59   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
60   SecretSharingGraphFactory factory;
61   auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
62       factory.CreateCompleteGraph(total_number_of_clients,
63                                   minimum_number_of_clients_to_proceed),
64       minimum_number_of_clients_to_proceed, input_vector_specs,
65       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
66       std::make_unique<AesCtrPrngFactory>(), sender,
67       nullptr,  // prng_runner
68       std::vector<ClientStatus>(total_number_of_clients,
69                                 ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED),
70       ServerVariant::NATIVE_V1);
71   impl->set_session_id(std::make_unique<SessionId>(session_id));
72   EcdhPregeneratedTestKeys ecdh_keys;
73 
74   for (int i = 0; i < total_number_of_clients; ++i) {
75     impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
76   }
77 
78   return impl;
79 }
80 
// A freshly constructed unmasking state (threshold 3 of 4 clients, no prior
// failures) must not report itself as aborted.
TEST(SecaggServerR3UnmaskingStateTest, IsAbortedReturnsFalse) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  SecAggServerR3UnmaskingState unmasking_state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     mock_sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool aborted = unmasking_state.IsAborted();
  EXPECT_THAT(aborted, IsFalse());
}
92 
// A newly constructed unmasking state must not report successful completion.
TEST(SecaggServerR3UnmaskingStateTest, IsCompletedSuccessfullyReturnsFalse) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  SecAggServerR3UnmaskingState unmasking_state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     mock_sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool completed = unmasking_state.IsCompletedSuccessfully();
  EXPECT_THAT(completed, IsFalse());
}
104 
// A state that has not aborted carries no error message, so ErrorMessage()
// must return a non-OK status.
TEST(SecaggServerR3UnmaskingStateTest, ErrorMessageRaisesErrorStatus) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  SecAggServerR3UnmaskingState unmasking_state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     mock_sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  auto error_message = unmasking_state.ErrorMessage();
  EXPECT_THAT(error_message.ok(), IsFalse());
}
116 
// While unmasking is still in progress no aggregation result is available, so
// Result() must return a non-OK status.
TEST(SecaggServerR3UnmaskingStateTest, ResultRaisesErrorStatus) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  SecAggServerR3UnmaskingState unmasking_state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     mock_sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  auto result = unmasking_state.Result();
  EXPECT_THAT(result.ok(), IsFalse());
}
128 
// Aborting a client (with notify == false) that has already delivered its
// unmasking response must record it in AbortedClientIds() without counting it
// as a post-masked-input failure, without logging a ClientsDropped metric and
// without sending it any message.
TEST(SecaggServerR3UnmaskingStateTest,
     AbortClientAfterUnmaskingResponseReceived) {
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership of `metrics` passes to the protocol impl created below.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics);
  impl->set_client_status(2, ClientStatus::UNMASKING_RESPONSE_RECEIVED);
  SecAggServerR3UnmaskingState state(
      std::move(impl),
      0,   // number_of_clients_failed_after_sending_masked_input
      0,   // number_of_clients_failed_before_sending_masked_input
      0);  // number_of_clients_terminated_without_unmasking

  // gMock only checks calls made after an expectation is set, so these must
  // precede AbortClient(); placed afterwards they would verify nothing.
  // Metrics are not logged.
  EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
  // Client is not notified.
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  state.AbortClient(2, "close client message.",
                    ClientDropReason::SENT_ABORT_MESSAGE, false);
  ASSERT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  ASSERT_THAT(state.AbortedClientIds().contains(2), IsTrue());
}
151 
// Abort() on the unmasking state must:
//   * broadcast an abort message carrying the given reason;
//   * report the EXTERNAL_REQUEST outcome to the metrics listener;
//   * record exactly one BroadcastMessageSent trace event of that size;
//   * return an ABORTED state whose error message is the reason.
TEST(SecaggServerR3UnmaskingStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  TestTracingRecorder recorder;
  // Adopted by the protocol impl created below.
  auto* metrics_listener = new MockSecAggServerMetricsListener();
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR3UnmaskingState state(
      CreateSecAggServerProtocolImpl(3, 4, mock_sender.get(),
                                     metrics_listener),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics_listener,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(*mock_sender, SendBroadcast(EqualsProto(expected_broadcast)));

  auto aborted =
      state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted->State(), Eq(SecAggServerStateKind::ABORTED));
  auto error_message = aborted->ErrorMessage();
  ASSERT_THAT(error_message, IsOk());
  EXPECT_THAT(error_message.value(), Eq("test abort reason"));
  EXPECT_THAT(recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
183 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyWithNoAbortsAndAllCorrectMessagesReceived)184 TEST(SecaggServerR3UnmaskingStateTest,
185      StateProceedsCorrectlyWithNoAbortsAndAllCorrectMessagesReceived) {
186   // In this test, no clients abort or aborted at any point, and all four
187   // clients send unmasking responses to the server before ProceedToNextRound is
188   // called.
189   auto sender = std::make_unique<MockSendToClientsInterface>();
190 
191   SecAggServerR3UnmaskingState state(
192       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
193       0,   // number_of_clients_failed_after_sending_masked_input
194       0,   // number_of_clients_failed_before_sending_masked_input
195       0);  // number_of_clients_terminated_without_unmasking
196 
197   // Set up correct responses
198   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
199   for (int i = 0; i < 4; ++i) {
200     for (int j = 0; j < 4; ++j) {
201       NoiseOrPrfKeyShare* share = unmasking_responses[i]
202                                       .mutable_unmasking_response()
203                                       ->add_noise_or_prf_key_shares();
204       share->set_prf_sk_share(
205           absl::StrCat("Test key share for client ", j, " from client ", i));
206     }
207   }
208 
209   // No clients should actually get a message in this round.
210   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
211   EXPECT_CALL(*sender, Send(_, _)).Times(0);
212 
213   // i is the number of messages received so far
214   for (int i = 0; i <= 4; ++i) {
215     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
216     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
217     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
218     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
219     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
220     if (i < 3) {
221       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
222       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
223     } else {
224       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
225       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
226     }
227     if (i < 4) {
228       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
229       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
230     }
231   }
232 
233   auto next_state = state.ProceedToNextRound();
234   ASSERT_THAT(next_state, IsOk());
235   EXPECT_THAT(next_state.value()->State(),
236               Eq(SecAggServerStateKind::PRNG_RUNNING));
237   EXPECT_THAT(
238       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
239       Eq(0));
240   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
241               Eq(0));
242 }
243 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyWithMinimumCorrectMessagesReceived)244 TEST(SecaggServerR3UnmaskingStateTest,
245      StateProceedsCorrectlyWithMinimumCorrectMessagesReceived) {
246   // In this test, no clients abort or aborted at any point, but
247   // ProceedToNextRound is called after only 3 clients have submitted masked
248   // input responses. This is perfectly valid because the threshold is 3.
249   auto sender = std::make_unique<MockSendToClientsInterface>();
250   MockSecAggServerMetricsListener* metrics =
251       new MockSecAggServerMetricsListener();
252 
253   SecAggServerR3UnmaskingState state(
254       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
255       0,   // number_of_clients_failed_after_sending_masked_input
256       0,   // number_of_clients_failed_before_sending_masked_input
257       0);  // number_of_clients_terminated_without_unmasking
258 
259   // Set up correct responses
260   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
261   for (int i = 0; i < 4; ++i) {
262     for (int j = 0; j < 4; ++j) {
263       NoiseOrPrfKeyShare* share = unmasking_responses[i]
264                                       .mutable_unmasking_response()
265                                       ->add_noise_or_prf_key_shares();
266       share->set_prf_sk_share(
267           absl::StrCat("Test key share for client ", j, " from client ", i));
268     }
269   }
270 
271   // Only client 3 should get a message this round.
272   ServerToClientWrapperMessage abort_message;
273   abort_message.mutable_abort()->set_early_success(true);
274   abort_message.mutable_abort()->set_diagnostic_info(
275       "Client did not send unmasking response but protocol completed "
276       "successfully.");
277   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
278   EXPECT_CALL(*sender, Send(Ne(3), _)).Times(0);
279   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
280   EXPECT_CALL(*metrics,
281               ClientsDropped(
282                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
283                   Eq(ClientDropReason::EARLY_SUCCESS)));
284 
285   // i is the number of messages received so far. Stop after 3
286   for (int i = 0; i <= 3; ++i) {
287     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
288     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
289     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
290     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
291     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
292     if (i < 3) {
293       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
294       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
295     } else {
296       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
297       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
298     }
299     if (i < 3) {
300       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
301       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
302     }
303   }
304 
305   auto next_state = state.ProceedToNextRound();
306   ASSERT_THAT(next_state, IsOk());
307   EXPECT_THAT(next_state.value()->State(),
308               Eq(SecAggServerStateKind::PRNG_RUNNING));
309   EXPECT_THAT(
310       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
311       Eq(0));
312   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
313               Eq(1));
314 }
315 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyWithOneFailure)316 TEST(SecaggServerR3UnmaskingStateTest, StateProceedsCorrectlyWithOneFailure) {
317   // In this test, no clients abort or aborted at any point, but client 0 sends
318   // an invalid message. It should be aborted, but the other 3 clients should be
319   // enough to proceed.
320   auto sender = std::make_unique<MockSendToClientsInterface>();
321 
322   SecAggServerR3UnmaskingState state(
323       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
324       0,   // number_of_clients_failed_after_sending_masked_input
325       0,   // number_of_clients_failed_before_sending_masked_input
326       0);  // number_of_clients_terminated_without_unmasking
327 
328   // Set up correct responses
329   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
330   for (int i = 0; i < 4; ++i) {
331     for (int j = 0; j < 4; ++j) {
332       NoiseOrPrfKeyShare* share = unmasking_responses[i]
333                                       .mutable_unmasking_response()
334                                       ->add_noise_or_prf_key_shares();
335       share->set_prf_sk_share(
336           absl::StrCat("Test key share for client ", j, " from client ", i));
337     }
338   }
339   // Add an incorrect response.
340   unmasking_responses[0]
341       .mutable_unmasking_response()
342       ->mutable_noise_or_prf_key_shares(2)
343       ->set_noise_sk_share("This is the wrong type of share!");
344 
345   // Only client 0 should get a message this round.
346   ServerToClientWrapperMessage abort_message;
347   abort_message.mutable_abort()->set_diagnostic_info(
348       "Client did not include the correct type of key share.");
349   abort_message.mutable_abort()->set_early_success(false);
350   EXPECT_CALL(*sender, Send(0, EqualsProto(abort_message))).Times(1);
351   EXPECT_CALL(*sender, Send(Ne(0), _)).Times(0);
352   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
353 
354   EXPECT_THAT(state.HandleMessage(0, unmasking_responses[0]), IsOk());
355   EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
356   EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(1));
357   EXPECT_THAT(state.AbortedClientIds().contains(0), IsTrue());
358 
359   // i is the number of messages received so far.
360   for (int i = 1; i <= 4; ++i) {
361     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
362     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
363     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
364     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
365     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
366     if (i < 4) {
367       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
368       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
369     } else {
370       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
371       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
372     }
373     if (i < 4) {
374       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
375       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
376     }
377   }
378 
379   auto next_state = state.ProceedToNextRound();
380   ASSERT_THAT(next_state, IsOk());
381   EXPECT_THAT(next_state.value()->State(),
382               Eq(SecAggServerStateKind::PRNG_RUNNING));
383   EXPECT_THAT(
384       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
385       Eq(1));
386   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
387               Eq(0));
388 }
389 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyWithAnAbortInRound2)390 TEST(SecaggServerR3UnmaskingStateTest,
391      StateProceedsCorrectlyWithAnAbortInRound2) {
392   // In this test, client 3 never sent a masked input, so clients should send
393   // the pairwise key share for client 3.
394   auto sender = std::make_unique<MockSendToClientsInterface>();
395   auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
396   impl->set_client_status(3, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
397 
398   SecAggServerR3UnmaskingState state(
399       std::move(impl),
400       0,   // number_of_clients_failed_after_sending_masked_input
401       1,   // number_of_clients_failed_before_sending_masked_input
402       0);  // number_of_clients_terminated_without_unmasking
403 
404   // Set up correct responses
405   std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
406   for (int i = 0; i < 3; ++i) {
407     for (int j = 0; j < 3; ++j) {
408       NoiseOrPrfKeyShare* share = unmasking_responses[i]
409                                       .mutable_unmasking_response()
410                                       ->add_noise_or_prf_key_shares();
411       share->set_prf_sk_share(
412           absl::StrCat("Test key share for client ", j, " from client ", i));
413     }
414     NoiseOrPrfKeyShare* share = unmasking_responses[i]
415                                     .mutable_unmasking_response()
416                                     ->add_noise_or_prf_key_shares();
417     share->set_noise_sk_share(
418         absl::StrCat("Test key share for client ", 3, " from client ", i));
419   }
420 
421   // No clients should actually get a message in this round.
422   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
423   EXPECT_CALL(*sender, Send(_, _)).Times(0);
424 
425   // i is the number of messages received so far
426   for (int i = 0; i <= 3; ++i) {
427     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
428     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
429     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
430     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
431     EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
432     if (i < 3) {
433       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
434       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
435     } else {
436       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
437       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
438     }
439     if (i < 3) {
440       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
441       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
442     }
443   }
444 
445   auto next_state = state.ProceedToNextRound();
446   ASSERT_THAT(next_state, IsOk());
447   EXPECT_THAT(next_state.value()->State(),
448               Eq(SecAggServerStateKind::PRNG_RUNNING));
449   EXPECT_THAT(
450       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
451       Eq(0));
452   EXPECT_THAT(
453       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
454       Eq(1));
455   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
456               Eq(0));
457 }
458 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyWithAnAbortInRound1)459 TEST(SecaggServerR3UnmaskingStateTest,
460      StateProceedsCorrectlyWithAnAbortInRound1) {
461   // In this test, client 3 never even finished the key share round, so the
462   // other clients should send no key share for client 3.
463   auto sender = std::make_unique<MockSendToClientsInterface>();
464   auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
465   impl->set_client_status(3, ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED);
466 
467   SecAggServerR3UnmaskingState state(
468       std::move(impl),
469       0,   // number_of_clients_failed_after_sending_masked_input
470       1,   // number_of_clients_failed_before_sending_masked_input
471       0);  // number_of_clients_terminated_without_unmasking
472 
473   // Set up correct responses
474   std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
475   for (int i = 0; i < 3; ++i) {
476     for (int j = 0; j < 3; ++j) {
477       NoiseOrPrfKeyShare* share = unmasking_responses[i]
478                                       .mutable_unmasking_response()
479                                       ->add_noise_or_prf_key_shares();
480       share->set_prf_sk_share(
481           absl::StrCat("Test key share for client ", j, " from client ", i));
482     }
483     unmasking_responses[i]
484         .mutable_unmasking_response()
485         ->add_noise_or_prf_key_shares();
486   }
487 
488   // No clients should actually get a message in this round.
489   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
490   EXPECT_CALL(*sender, Send(_, _)).Times(0);
491 
492   // i is the number of messages received so far
493   for (int i = 0; i <= 3; ++i) {
494     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
495     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
496     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
497     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
498     EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
499     if (i < 3) {
500       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
501       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
502     } else {
503       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
504       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
505     }
506     if (i < 3) {
507       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
508       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
509     }
510   }
511 
512   auto next_state = state.ProceedToNextRound();
513   ASSERT_THAT(next_state, IsOk());
514   EXPECT_THAT(next_state.value()->State(),
515               Eq(SecAggServerStateKind::PRNG_RUNNING));
516   EXPECT_THAT(
517       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
518       Eq(0));
519   EXPECT_THAT(
520       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
521       Eq(1));
522   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
523               Eq(0));
524 }
525 
TEST(SecaggServerR3UnmaskingStateTest,StateProceedsCorrectlyEvenIfClientsAbortAfterSendingMessage)526 TEST(SecaggServerR3UnmaskingStateTest,
527      StateProceedsCorrectlyEvenIfClientsAbortAfterSendingMessage) {
528   // In this test, clients 0 and 1 send valid messages but then abort. But since
529   // they sent valid messages, the server should proceed regardless.
530   auto sender = std::make_unique<MockSendToClientsInterface>();
531 
532   SecAggServerR3UnmaskingState state(
533       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
534       0,   // number_of_clients_failed_after_sending_masked_input
535       0,   // number_of_clients_failed_before_sending_masked_input
536       0);  // number_of_clients_terminated_without_unmasking
537 
538   // Set up correct responses
539   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
540   for (int i = 0; i < 4; ++i) {
541     for (int j = 0; j < 4; ++j) {
542       NoiseOrPrfKeyShare* share = unmasking_responses[i]
543                                       .mutable_unmasking_response()
544                                       ->add_noise_or_prf_key_shares();
545       share->set_prf_sk_share(
546           absl::StrCat("Test key share for client ", j, " from client ", i));
547     }
548   }
549 
550   // No clients should actually get a message in this round.
551   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
552   EXPECT_CALL(*sender, Send(_, _)).Times(0);
553 
554   ClientToServerWrapperMessage abort_message;
555   abort_message.mutable_abort();
556 
557   // i is the number of messages received so far
558   for (int i = 0; i <= 4; ++i) {
559     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
560     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
561     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
562     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
563     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
564     if (i < 3) {
565       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
566       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
567     } else {
568       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
569       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
570     }
571     if (i < 4) {
572       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
573       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
574     }
575   }
576   // These should not change anything.
577   EXPECT_THAT(state.HandleMessage(0, abort_message), IsOk());
578   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
579   EXPECT_THAT(state.HandleMessage(1, abort_message), IsOk());
580   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
581   EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
582   EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(4));
583   EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
584   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
585 
586   auto next_state = state.ProceedToNextRound();
587   ASSERT_THAT(next_state, IsOk());
588   EXPECT_THAT(next_state.value()->State(),
589               Eq(SecAggServerStateKind::PRNG_RUNNING));
590   EXPECT_THAT(
591       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
592       Eq(0));
593   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
594               Eq(0));
595 }
596 
TEST(SecaggServerR3UnmaskingStateTest,StateAbortsIfTooManyClientsAbort)597 TEST(SecaggServerR3UnmaskingStateTest, StateAbortsIfTooManyClientsAbort) {
598   // In this test, clients 0 and 1 send abort messages rather than valid
599   // unmasking responses, so the server must abort
600   TestTracingRecorder tracing_recorder;
601   auto sender = std::make_unique<MockSendToClientsInterface>();
602 
603   SecAggServerR3UnmaskingState state(
604       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
605       0,   // number_of_clients_failed_after_sending_masked_input
606       0,   // number_of_clients_failed_before_sending_masked_input
607       0);  // number_of_clients_terminated_without_unmasking
608 
609   // Set up correct responses
610   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
611   for (int i = 0; i < 4; ++i) {
612     for (int j = 0; j < 4; ++j) {
613       NoiseOrPrfKeyShare* share = unmasking_responses[i]
614                                       .mutable_unmasking_response()
615                                       ->add_noise_or_prf_key_shares();
616       share->set_prf_sk_share(
617           absl::StrCat("Test key share for client ", j, " from client ", i));
618     }
619   }
620 
621   // No individual clients should get a message, but the server should broadcast
622   // an abort message
623   ServerToClientWrapperMessage server_abort_message;
624   server_abort_message.mutable_abort()->set_diagnostic_info(
625       "Too many clients aborted.");
626   server_abort_message.mutable_abort()->set_early_success(false);
627   EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_abort_message)))
628       .Times(1);
629   EXPECT_CALL(*sender, Send(_, _)).Times(0);
630 
631   ClientToServerWrapperMessage client_abort_message;
632   client_abort_message.mutable_abort();
633 
634   ASSERT_THAT(state.HandleMessage(0, client_abort_message), IsOk());
635   EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
636   EXPECT_THAT(state.NeedsToAbort(), IsFalse());
637   EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
638   ASSERT_THAT(state.HandleMessage(1, client_abort_message), IsOk());
639   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
640   EXPECT_THAT(state.NeedsToAbort(), IsTrue());
641   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
642 
643   auto next_state = state.ProceedToNextRound();
644   ASSERT_THAT(next_state, IsOk());
645   EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
646   EXPECT_THAT(next_state.value()->ErrorMessage(), IsOk());
647   EXPECT_THAT(next_state.value()->ErrorMessage().value(),
648               Eq("Too many clients aborted."));
649   EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
650               ElementsAre(IsEvent<BroadcastMessageSent>(
651                   Eq(ServerToClientMessageType_Abort),
652                   Eq(server_abort_message.ByteSizeLong()))));
653 }
654 
TEST(SecaggServerR3UnmaskingStateTest,MetricsRecordsMessageSizes)655 TEST(SecaggServerR3UnmaskingStateTest, MetricsRecordsMessageSizes) {
656   // In this test, client 3 never sent a masked input, so clients should send
657   // the pairwise key share for client 3.
658   TestTracingRecorder tracing_recorder;
659   MockSecAggServerMetricsListener* metrics =
660       new MockSecAggServerMetricsListener();
661   auto sender = std::make_unique<MockSendToClientsInterface>();
662   auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics);
663   impl->set_client_status(3, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
664 
665   SecAggServerR3UnmaskingState state(
666       std::move(impl),
667       0,   // number_of_clients_failed_after_sending_masked_input
668       1,   // number_of_clients_failed_before_sending_masked_input
669       0);  // number_of_clients_terminated_without_unmasking
670 
671   // Set up correct responses
672   std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
673   for (int i = 0; i < 3; ++i) {
674     for (int j = 0; j < 3; ++j) {
675       NoiseOrPrfKeyShare* share = unmasking_responses[i]
676                                       .mutable_unmasking_response()
677                                       ->add_noise_or_prf_key_shares();
678       share->set_prf_sk_share(
679           absl::StrCat("Test key share for client ", j, " from client ", i));
680     }
681     NoiseOrPrfKeyShare* share = unmasking_responses[i]
682                                     .mutable_unmasking_response()
683                                     ->add_noise_or_prf_key_shares();
684     share->set_noise_sk_share(
685         absl::StrCat("Test key share for client ", 3, " from client ", i));
686   }
687 
688   // No clients should actually get a message in this round.
689   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
690   EXPECT_CALL(*sender, Send(_, _)).Times(0);
691   EXPECT_CALL(
692       *metrics,
693       MessageReceivedSizes(Eq(ClientToServerWrapperMessage::MessageContentCase::
694                                   kUnmaskingResponse),
695                            Eq(true), Eq(unmasking_responses[0].ByteSizeLong())))
696       .Times(3);
697 
698   // i is the number of messages received so far
699   for (int i = 0; i <= 3; ++i) {
700     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
701     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
702     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
703     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
704     EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
705     if (i < 3) {
706       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
707       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
708     } else {
709       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
710       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
711     }
712     if (i < 3) {
713       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
714       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
715       EXPECT_THAT(
716           tracing_recorder.root()[i],
717           IsEvent<ClientMessageReceived>(
718               Eq(ClientToServerMessageType_UnmaskingResponse),
719               Eq(unmasking_responses[i].ByteSizeLong()), Eq(true), Ge(0)));
720     }
721   }
722 
723   auto next_state = state.ProceedToNextRound();
724   ASSERT_THAT(next_state, IsOk());
725   EXPECT_THAT(next_state.value()->State(),
726               Eq(SecAggServerStateKind::PRNG_RUNNING));
727   EXPECT_THAT(
728       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
729       Eq(0));
730   EXPECT_THAT(
731       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
732       Eq(1));
733   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
734               Eq(0));
735 }
736 
TEST(SecaggServerR3UnmaskingStateTest,ServerAndClientAbortsAreRecordedCorrectly)737 TEST(SecaggServerR3UnmaskingStateTest,
738      ServerAndClientAbortsAreRecordedCorrectly) {
739   // In this test clients abort for a variety of reasons, and then ultimately
740   // the server aborts. Metrics should record all of these events.
741   MockSecAggServerMetricsListener* metrics =
742       new MockSecAggServerMetricsListener();
743   auto sender = std::make_unique<MockSendToClientsInterface>();
744   auto impl = CreateSecAggServerProtocolImpl(2, 8, sender.get(), metrics);
745   impl->ErasePublicKeysForClient(7);
746   impl->set_client_status(6, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
747   impl->set_client_status(7, ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
748 
749   SecAggServerR3UnmaskingState state(
750       std::move(impl),
751       0,   // number_of_clients_failed_after_sending_masked_input
752       2,   // number_of_clients_failed_before_sending_masked_input
753       0);  // number_of_clients_terminated_without_unmasking
754 
755   EXPECT_CALL(*metrics,
756               ClientsDropped(
757                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
758                   Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
759   EXPECT_CALL(*metrics,
760               ClientsDropped(
761                   Eq(ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED), _))
762       .Times(0);
763   EXPECT_CALL(*metrics,
764               ClientsDropped(
765                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
766                   Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
767   EXPECT_CALL(*metrics,
768               ClientsDropped(
769                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
770                   Eq(ClientDropReason::INVALID_UNMASKING_RESPONSE)))
771       .Times(3);
772   EXPECT_CALL(
773       *metrics,
774       ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
775 
776   ClientToServerWrapperMessage abort_message;
777   abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
778 
779   ClientToServerWrapperMessage valid_message;  // from client 1
780   for (int j = 0; j < 6; ++j) {
781     NoiseOrPrfKeyShare* share = valid_message.mutable_unmasking_response()
782                                     ->add_noise_or_prf_key_shares();
783     share->set_prf_sk_share(
784         absl::StrCat("Test key share for client ", j, " from client 1"));
785   }
786   NoiseOrPrfKeyShare* share =
787       valid_message.mutable_unmasking_response()->add_noise_or_prf_key_shares();
788   share->set_noise_sk_share(
789       absl::StrCat("Test key share for client ", 6, " from client 1"));
790   share =
791       valid_message.mutable_unmasking_response()->add_noise_or_prf_key_shares();
792 
793   ClientToServerWrapperMessage invalid_noise_instead_of_prf;  // from client 2
794   for (int j = 0; j < 5; ++j) {
795     share = invalid_noise_instead_of_prf.mutable_unmasking_response()
796                 ->add_noise_or_prf_key_shares();
797     share->set_prf_sk_share(
798         absl::StrCat("Test key share for client ", j, " from client 2"));
799   }
800   for (int j = 5; j < 7; ++j) {  // client 5 should not be included here
801     share = invalid_noise_instead_of_prf.mutable_unmasking_response()
802                 ->add_noise_or_prf_key_shares();
803     share->set_noise_sk_share(
804         absl::StrCat("Test key share for client ", j, " from client 2"));
805   }
806   share = invalid_noise_instead_of_prf.mutable_unmasking_response()
807               ->add_noise_or_prf_key_shares();
808 
809   ClientToServerWrapperMessage invalid_prf_instead_of_noise;  // from client 3
810   for (int j = 0; j < 7; ++j) {  // client 6 should not be included here
811     share = invalid_prf_instead_of_noise.mutable_unmasking_response()
812                 ->add_noise_or_prf_key_shares();
813     share->set_prf_sk_share(
814         absl::StrCat("Test key share for client ", j, " from client 3"));
815   }
816   share = invalid_prf_instead_of_noise.mutable_unmasking_response()
817               ->add_noise_or_prf_key_shares();
818 
819   ClientToServerWrapperMessage invalid_noise_instead_of_blank;  // from client 4
820   for (int j = 0; j < 6; ++j) {
821     share = invalid_noise_instead_of_blank.mutable_unmasking_response()
822                 ->add_noise_or_prf_key_shares();
823     share->set_prf_sk_share(
824         absl::StrCat("Test key share for client ", j, " from client 4"));
825   }
826   for (int j = 6; j < 8; ++j) {  // client 7 should not be included here
827     share = invalid_noise_instead_of_blank.mutable_unmasking_response()
828                 ->add_noise_or_prf_key_shares();
829     share->set_noise_sk_share(
830         absl::StrCat("Test key share for client ", j, " from client 4"));
831   }
832 
833   ClientToServerWrapperMessage wrong_message;
834   wrong_message.mutable_advertise_keys();  // wrong type of message
835 
836   state.HandleMessage(0, abort_message).IgnoreError();
837   state.HandleMessage(1, valid_message).IgnoreError();
838   state.HandleMessage(1, valid_message).IgnoreError();
839   state.HandleMessage(2, invalid_noise_instead_of_prf).IgnoreError();
840   state.HandleMessage(3, invalid_prf_instead_of_noise).IgnoreError();
841   state.HandleMessage(4, invalid_noise_instead_of_blank).IgnoreError();
842   state.HandleMessage(5, wrong_message).IgnoreError();
843   state.ProceedToNextRound().IgnoreError();  // causes server abort
844 }
845 
TEST(SecaggServerR3UnmaskingStateTest,MetricsAreRecorded)846 TEST(SecaggServerR3UnmaskingStateTest, MetricsAreRecorded) {
847   // In this test, no clients abort or aborted at any point, but
848   // ProceedToNextRound is called after only 3 clients have submitted masked
849   // input responses. This is perfectly valid because the threshold is 3.
850   MockSecAggServerMetricsListener* metrics =
851       new MockSecAggServerMetricsListener();
852   auto sender = std::make_unique<MockSendToClientsInterface>();
853 
854   SecAggServerR3UnmaskingState state(
855       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
856       0,   // number_of_clients_failed_after_sending_masked_input
857       0,   // number_of_clients_failed_before_sending_masked_input
858       0);  // number_of_clients_terminated_without_unmasking
859 
860   // Set up correct responses
861   std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
862   for (int i = 0; i < 4; ++i) {
863     for (int j = 0; j < 4; ++j) {
864       NoiseOrPrfKeyShare* share = unmasking_responses[i]
865                                       .mutable_unmasking_response()
866                                       ->add_noise_or_prf_key_shares();
867       share->set_prf_sk_share(
868           absl::StrCat("Test key share for client ", j, " from client ", i));
869     }
870   }
871 
872   // Only client 3 should get a message this round.
873   ServerToClientWrapperMessage abort_message;
874   abort_message.mutable_abort()->set_diagnostic_info(
875       "Client did not send unmasking response but protocol completed "
876       "successfully.");
877   abort_message.mutable_abort()->set_early_success(true);
878   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
879   EXPECT_CALL(*sender, Send(Ne(3), _)).Times(0);
880   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
881 
882   EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R3_UNMASKING),
883                                    Eq(true), Ge(0)));
884   EXPECT_CALL(*metrics, RoundSurvivingClients(
885                             Eq(SecAggServerStateKind::R3_UNMASKING), Eq(3)));
886   EXPECT_CALL(*metrics, ClientResponseTimes(
887                             Eq(ClientToServerWrapperMessage::
888                                    MessageContentCase::kUnmaskingResponse),
889                             Ge(0)))
890       .Times(3);
891   EXPECT_CALL(*metrics,
892               ClientsDropped(
893                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
894                   Eq(ClientDropReason::EARLY_SUCCESS)));
895 
896   // i is the number of messages received so far. Stop after 3
897   for (int i = 0; i <= 3; ++i) {
898     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
899     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
900     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
901     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
902     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
903     if (i < 3) {
904       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
905       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
906     } else {
907       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
908       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
909     }
910     if (i < 3) {
911       ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
912       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
913     }
914   }
915 
916   auto next_state = state.ProceedToNextRound();
917   ASSERT_THAT(next_state, IsOk());
918   EXPECT_THAT(next_state.value()->State(),
919               Eq(SecAggServerStateKind::PRNG_RUNNING));
920 }
921 
922 }  // namespace
923 }  // namespace secagg
924 }  // namespace fcp
925