1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r1_share_keys_state.h"
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/strings/str_cat.h"
25 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
26 #include "fcp/secagg/server/secagg_server_state.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
29 #include "fcp/secagg/shared/compute_session_id.h"
30 #include "fcp/secagg/shared/ecdh_keys.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/shared/shamir_secret_sharing.h"
34 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
35 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
36 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
37 #include "fcp/testing/testing.h"
38 #include "fcp/tracing/test_tracing_recorder.h"
39 
40 namespace fcp {
41 namespace secagg {
42 namespace {
43 
44 using ::testing::_;
45 using ::testing::Eq;
46 using ::testing::Ge;
47 using ::testing::IsFalse;
48 using ::testing::IsTrue;
49 
50 // Default test session_id.
51 SessionId session_id = {"session id number, 32 bytes long"};
52 
CreateSecAggServerProtocolImpl(int minimum_number_of_clients_to_proceed,int total_number_of_clients,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)53 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
54     int minimum_number_of_clients_to_proceed, int total_number_of_clients,
55     MockSendToClientsInterface* sender,
56     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
57   auto input_vector_specs = std::vector<InputVectorSpecification>();
58   SecretSharingGraphFactory factory;
59   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
60   auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
61       factory.CreateCompleteGraph(total_number_of_clients,
62                                   minimum_number_of_clients_to_proceed),
63       minimum_number_of_clients_to_proceed, input_vector_specs,
64       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
65       std::make_unique<AesCtrPrngFactory>(), sender,
66       std::make_unique<SecAggScheduler>(
67           /*sequential_scheduler=*/nullptr,
68           /*parallel_scheduler=*/nullptr),
69       std::vector<ClientStatus>(total_number_of_clients,
70                                 ClientStatus::ADVERTISE_KEYS_RECEIVED),
71       ServerVariant::NATIVE_V1);
72   impl->set_session_id(std::make_unique<SessionId>(session_id));
73   EcdhPregeneratedTestKeys ecdh_keys;
74   for (int i = 0; i < total_number_of_clients; i++) {
75     impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
76   }
77   return impl;
78 }
79 
TEST(SecaggServerR1ShareKeysStateTest, IsAbortedReturnsFalse) {
  // A freshly constructed R1 share-keys state is not aborted.
  auto sender = std::make_shared<MockSendToClientsInterface>();
  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.IsAborted());
}
92 
TEST(SecaggServerR1ShareKeysStateTest, IsCompletedSuccessfullyReturnsFalse) {
  // The share-keys round is not a terminal success state.
  auto sender = std::make_shared<MockSendToClientsInterface>();
  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.IsCompletedSuccessfully());
}
105 
TEST(SecaggServerR1ShareKeysStateTest, ErrorMessageRaisesErrorStatus) {
  // Asking a non-aborted state for its error message yields a non-OK status.
  auto sender = std::make_shared<MockSendToClientsInterface>();
  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.ErrorMessage().ok());
}
118 
TEST(SecaggServerR1ShareKeysStateTest, ResultRaisesErrorStatus) {
  // No aggregation result is available during the share-keys round.
  auto sender = std::make_shared<MockSendToClientsInterface>();
  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.Result().ok());
}
131 
TEST(SecaggServerR1ShareKeysStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  // Aborting on external request must:
  //  - record the EXTERNAL_REQUEST outcome;
  //  - broadcast an abort message carrying the given reason;
  //  - trace exactly that broadcast;
  //  - return an ABORTED state that reports the reason.
  TestTracingRecorder tracing_recorder;
  // Ownership passes to the protocol impl; the raw pointer is kept only so
  // expectations can be set on the mock.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));

  auto aborted_state =
      state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(), Eq("test abort reason"));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
164 
TEST(SecaggServerR1ShareKeysStateTest,StateProceedsCorrectlyWithAllClientsValid)165 TEST(SecaggServerR1ShareKeysStateTest,
166      StateProceedsCorrectlyWithAllClientsValid) {
167   // In this test, all clients send inputs for the correct clients, and then the
168   // server proceeds to the next state. (The inputs aren't actually encrypted
169   // shared keys, but that doesn't matter for this test.)
170   auto sender = std::make_shared<MockSendToClientsInterface>();
171 
172   SecAggServerR1ShareKeysState state(
173       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
174       0,  // number_of_clients_failed_after_sending_masked_input
175       0,  // number_of_clients_failed_before_sending_masked_input
176       0   // number_of_clients_terminated_without_unmasking
177   );
178 
179   for (int i = 0; i < 5; ++i) {
180     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
181     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
182     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
183     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
184     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
185     if (i < 3) {
186       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
187       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
188     } else {
189       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
190       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
191     }
192     if (i < 4) {
193       // Have one client send the right vector of "encrypted keys" to the
194       // server.
195       ClientToServerWrapperMessage client_message;
196       for (int j = 0; j < 4; ++j) {
197         if (i == j) {
198           client_message.mutable_share_keys_response()
199               ->add_encrypted_key_shares("");
200         } else {
201           client_message.mutable_share_keys_response()
202               ->add_encrypted_key_shares(
203                   absl::StrCat("encrypted key shares from ", i, " to ", j));
204         }
205       }
206       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
207       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
208     }
209   }
210   std::vector<ServerToClientWrapperMessage> server_messages(4);
211   for (int i = 0; i < 4; ++i) {
212     for (int j = 0; j < 4; ++j) {
213       if (i == j) {
214         server_messages[i]
215             .mutable_masked_input_request()
216             ->add_encrypted_key_shares("");
217       } else {
218         server_messages[i]
219             .mutable_masked_input_request()
220             ->add_encrypted_key_shares(
221                 absl::StrCat("encrypted key shares from ", j, " to ", i));
222       }
223     }
224     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
225   }
226   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
227 
228   auto next_state = state.ProceedToNextRound();
229   ASSERT_THAT(next_state, IsOk());
230   EXPECT_THAT(next_state.value()->State(),
231               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
232   EXPECT_THAT(
233       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
234       Eq(0));
235   EXPECT_THAT(
236       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
237       Eq(0));
238   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
239               Eq(0));
240 }
241 
TEST(SecaggServerR1ShareKeysStateTest,StateProceedsCorrectlyWithOnePreviousDropout)242 TEST(SecaggServerR1ShareKeysStateTest,
243      StateProceedsCorrectlyWithOnePreviousDropout) {
244   // In this test, client 3 dropped out in round 0, so clients should not send
245   // key shares for it. All other clients proceed normally.
246   auto sender = std::make_shared<MockSendToClientsInterface>();
247   auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
248   impl->set_client_status(3, ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
249 
250   SecAggServerR1ShareKeysState state(
251       std::move(impl),
252       0,  // number_of_clients_failed_after_sending_masked_input
253       1,  // number_of_clients_failed_before_sending_masked_input
254       0   // number_of_clients_terminated_without_unmasking
255   );
256 
257   for (int i = 0; i < 4; ++i) {
258     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
259     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
260     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
261     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
262     EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
263     if (i < 3) {
264       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
265       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
266     } else {
267       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
268       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
269     }
270     if (i < 3) {
271       // Have one client send the right vector of "encrypted keys" to the
272       // server.
273       ClientToServerWrapperMessage client_message;
274       for (int j = 0; j < 4; ++j) {
275         if (i == j || j == 3) {
276           client_message.mutable_share_keys_response()
277               ->add_encrypted_key_shares("");
278         } else {
279           client_message.mutable_share_keys_response()
280               ->add_encrypted_key_shares(
281                   absl::StrCat("encrypted key shares from ", i, " to ", j));
282         }
283       }
284       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
285       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
286     }
287   }
288   std::vector<ServerToClientWrapperMessage> server_messages(3);
289   for (int i = 0; i < 3; ++i) {
290     for (int j = 0; j < 4; ++j) {
291       if (i == j || j == 3) {
292         server_messages[i]
293             .mutable_masked_input_request()
294             ->add_encrypted_key_shares("");
295       } else {
296         server_messages[i]
297             .mutable_masked_input_request()
298             ->add_encrypted_key_shares(
299                 absl::StrCat("encrypted key shares from ", j, " to ", i));
300       }
301     }
302     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
303   }
304   EXPECT_CALL(*sender, Send(Eq(3), _)).Times(0);
305   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
306 
307   auto next_state = state.ProceedToNextRound();
308   ASSERT_THAT(next_state, IsOk());
309   EXPECT_THAT(next_state.value()->State(),
310               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
311   EXPECT_THAT(
312       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
313       Eq(0));
314   EXPECT_THAT(
315       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
316       Eq(1));
317   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
318               Eq(0));
319 }
320 
TEST(SecaggServerR1ShareKeysStateTest,StateProceedsCorrectlyWithAnAbortAfterSendingShares)321 TEST(SecaggServerR1ShareKeysStateTest,
322      StateProceedsCorrectlyWithAnAbortAfterSendingShares) {
323   // In this test, all clients send inputs for the correct clients, but then
324   // client 2 aborts. This should cause that client's message shared keys not to
325   // appear in the messages sent later.
326   auto sender = std::make_shared<MockSendToClientsInterface>();
327 
328   SecAggServerR1ShareKeysState state(
329       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
330       0,  // number_of_clients_failed_after_sending_masked_input
331       0,  // number_of_clients_failed_before_sending_masked_input
332       0   // number_of_clients_terminated_without_unmasking
333   );
334 
335   for (int i = 0; i < 5; ++i) {
336     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
337     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
338     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
339     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
340     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
341     if (i < 3) {
342       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
343       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
344     } else {
345       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
346       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
347     }
348     if (i < 4) {
349       // Have one client send the right vector of "encrypted key shares" to
350       // the server.
351       ClientToServerWrapperMessage client_message;
352       for (int j = 0; j < 4; ++j) {
353         if (i == j) {
354           client_message.mutable_share_keys_response()
355               ->add_encrypted_key_shares("");
356         } else {
357           client_message.mutable_share_keys_response()
358               ->add_encrypted_key_shares(
359                   absl::StrCat("encrypted key shares from ", i, " to ", j));
360         }
361       }
362       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
363       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
364     }
365   }
366 
367   ClientToServerWrapperMessage abort_message;
368   abort_message.mutable_abort()->set_diagnostic_info("aborting for test");
369   ASSERT_THAT(state.HandleMessage(2, abort_message), IsOk());
370   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
371 
372   std::vector<ServerToClientWrapperMessage> server_messages(4);
373   for (int i = 0; i < 4; ++i) {
374     if (i == 2) {
375       EXPECT_CALL(*sender, Send(Eq(2), _)).Times(0);
376       continue;
377     }
378     for (int j = 0; j < 4; ++j) {
379       if (i == j || j == 2) {
380         server_messages[i]
381             .mutable_masked_input_request()
382             ->add_encrypted_key_shares("");
383       } else {
384         server_messages[i]
385             .mutable_masked_input_request()
386             ->add_encrypted_key_shares(
387                 absl::StrCat("encrypted key shares from ", j, " to ", i));
388       }
389     }
390     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
391   }
392   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
393 
394   auto next_state = state.ProceedToNextRound();
395   ASSERT_THAT(next_state, IsOk());
396   EXPECT_THAT(next_state.value()->State(),
397               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
398   EXPECT_THAT(
399       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
400       Eq(0));
401   EXPECT_THAT(
402       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
403       Eq(1));
404   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
405               Eq(0));
406 }
407 
TEST(SecaggServerR1ShareKeysStateTest,StateProceedsCorrectlyWithOneClientSendingInvalidShares)408 TEST(SecaggServerR1ShareKeysStateTest,
409      StateProceedsCorrectlyWithOneClientSendingInvalidShares) {
410   // In this test, all clients send encrypted shares, but client 0 omits an
411   // encrypted share for client 1. This should force client 0 to abort.
412   auto sender = std::make_shared<MockSendToClientsInterface>();
413 
414   SecAggServerR1ShareKeysState state(
415       CreateSecAggServerProtocolImpl(3, 4, sender.get()),
416       0,  // number_of_clients_failed_after_sending_masked_input
417       0,  // number_of_clients_failed_before_sending_masked_input
418       0   // number_of_clients_terminated_without_unmasking
419   );
420 
421   std::vector<ServerToClientWrapperMessage> server_messages(4);
422   server_messages[0].mutable_abort()->set_early_success(false);
423   server_messages[0].mutable_abort()->set_diagnostic_info(
424       "Client omitted a key share that was expected.");
425   EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(server_messages[0]))).Times(1);
426   for (int i = 1; i < 4; ++i) {
427     for (int j = 0; j < 4; ++j) {
428       if (i == j || j == 0) {
429         server_messages[i]
430             .mutable_masked_input_request()
431             ->add_encrypted_key_shares("");
432       } else {
433         server_messages[i]
434             .mutable_masked_input_request()
435             ->add_encrypted_key_shares(
436                 absl::StrCat("encrypted key shares from ", j, " to ", i));
437       }
438     }
439     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
440   }
441   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
442 
443   ClientToServerWrapperMessage bad_message;
444   bad_message.mutable_share_keys_response()->add_encrypted_key_shares("");
445   bad_message.mutable_share_keys_response()->add_encrypted_key_shares("");
446   bad_message.mutable_share_keys_response()->add_encrypted_key_shares(
447       "encrypted key shares from 0 to 2");
448   bad_message.mutable_share_keys_response()->add_encrypted_key_shares(
449       "encrypted key shares from 0 to 3");
450   ASSERT_THAT(state.HandleMessage(0, bad_message), IsOk());
451   EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
452 
453   for (int i = 1; i < 5; ++i) {
454     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
455     EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
456     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
457     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
458     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
459     if (i < 4) {
460       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
461       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
462     } else {
463       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
464       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
465     }
466     if (i < 4) {
467       // Have one client send the right vector of "encrypted key shares" to
468       // the server.
469       ClientToServerWrapperMessage client_message;
470       for (int j = 0; j < 4; ++j) {
471         if (i == j) {
472           client_message.mutable_share_keys_response()
473               ->add_encrypted_key_shares("");
474         } else {
475           client_message.mutable_share_keys_response()
476               ->add_encrypted_key_shares(
477                   absl::StrCat("encrypted key shares from ", i, " to ", j));
478         }
479       }
480       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
481       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
482     }
483   }
484 
485   auto next_state = state.ProceedToNextRound();
486   ASSERT_THAT(next_state, IsOk());
487   EXPECT_THAT(next_state.value()->State(),
488               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
489   EXPECT_THAT(
490       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
491       Eq(0));
492   EXPECT_THAT(
493       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
494       Eq(1));
495   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
496               Eq(0));
497 }
498 
TEST(SecaggServerR1ShareKeysStateTest, StateAbortsIfTooManyClientsAbort) {
  // Clients 0 and 1 abort, leaving 2 of 4 clients alive with a threshold of 3.
  // The state must then report that it needs to abort. Proceeding must
  // broadcast an abort message (and send no individual messages), trace that
  // broadcast, and enter ABORTED.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  for (int aborted = 0; aborted <= 2; ++aborted) {
    const bool too_few_alive = aborted >= 2;
    EXPECT_EQ(state.NeedsToAbort(), too_few_alive);
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - aborted));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - aborted));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_EQ(state.ReadyForNextRound(), too_few_alive);
    if (too_few_alive) break;

    // Client `aborted` sends an abort.
    ClientToServerWrapperMessage client_abort;
    client_abort.mutable_abort()->set_diagnostic_info("Aborting for test");
    ASSERT_THAT(state.HandleMessage(aborted, client_abort), IsOk());
    EXPECT_EQ(state.ReadyForNextRound(), aborted >= 1);
  }

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("Too many clients aborted.");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)))
      .Times(1);
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& aborted_state = next_state.value();
  EXPECT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(),
              Eq("Too many clients aborted."));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
547 
TEST(SecaggServerR1ShareKeysStateTest,MetricsRecordsMessageSizes)548 TEST(SecaggServerR1ShareKeysStateTest, MetricsRecordsMessageSizes) {
549   // In this test, all clients send inputs for the correct clients, and then the
550   // server proceeds to the next state. (The inputs aren't actually encrypted
551   // shared keys, but that doesn't matter for this test.)
552   TestTracingRecorder tracing_recorder;
553   MockSecAggServerMetricsListener* metrics =
554       new MockSecAggServerMetricsListener();
555   auto sender = std::make_shared<MockSendToClientsInterface>();
556 
557   SecAggServerR1ShareKeysState state(
558       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
559       0,  // number_of_clients_failed_after_sending_masked_input
560       0,  // number_of_clients_failed_before_sending_masked_input
561       0   // number_of_clients_terminated_without_unmasking
562   );
563 
564   for (int i = 0; i < 5; ++i) {
565     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
566     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
567     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
568     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
569     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
570     if (i < 3) {
571       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
572       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
573     } else {
574       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
575       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
576     }
577     if (i < 4) {
578       // Have one client send the right vector of "encrypted keys" to the
579       // server.
580       ClientToServerWrapperMessage client_message;
581       for (int j = 0; j < 4; ++j) {
582         if (i == j) {
583           client_message.mutable_share_keys_response()
584               ->add_encrypted_key_shares("");
585         } else {
586           client_message.mutable_share_keys_response()
587               ->add_encrypted_key_shares(
588                   absl::StrCat("encrypted key shares from ", i, " to ", j));
589         }
590       }
591       EXPECT_CALL(*metrics, MessageReceivedSizes(
592                                 Eq(ClientToServerWrapperMessage::
593                                        MessageContentCase::kShareKeysResponse),
594                                 Eq(true), Eq(client_message.ByteSizeLong())));
595       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
596       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
597       EXPECT_THAT(tracing_recorder.root()[i],
598                   IsEvent<ClientMessageReceived>(
599                       Eq(ClientToServerMessageType_ShareKeysResponse),
600                       Eq(client_message.ByteSizeLong()), Eq(true), Ge(0)));
601     }
602   }
603   std::vector<ServerToClientWrapperMessage> server_messages(4);
604   for (int i = 0; i < 4; ++i) {
605     for (int j = 0; j < 4; ++j) {
606       if (i == j) {
607         server_messages[i]
608             .mutable_masked_input_request()
609             ->add_encrypted_key_shares("");
610       } else {
611         server_messages[i]
612             .mutable_masked_input_request()
613             ->add_encrypted_key_shares(
614                 absl::StrCat("encrypted key shares from ", j, " to ", i));
615       }
616     }
617     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i])));
618   }
619   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
620   EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
621   EXPECT_CALL(*metrics, IndividualMessageSizes(
622                             Eq(ServerToClientWrapperMessage::
623                                    MessageContentCase::kMaskedInputRequest),
624                             Eq(server_messages[0].ByteSizeLong())))
625       .Times(4);
626 
627   auto next_state = state.ProceedToNextRound();
628   ASSERT_THAT(next_state, IsOk());
629   EXPECT_THAT(next_state.value()->State(),
630               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
631   EXPECT_THAT(
632       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
633       Eq(0));
634   EXPECT_THAT(
635       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
636       Eq(0));
637   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
638               Eq(0));
639   EXPECT_THAT(
640       tracing_recorder.FindAllEvents<IndividualMessageSent>(),
641       ElementsAre(IsEvent<IndividualMessageSent>(
642                       0, Eq(ServerToClientMessageType_MaskedInputRequest),
643                       Eq(server_messages[0].ByteSizeLong())),
644                   IsEvent<IndividualMessageSent>(
645                       1, Eq(ServerToClientMessageType_MaskedInputRequest),
646                       Eq(server_messages[1].ByteSizeLong())),
647                   IsEvent<IndividualMessageSent>(
648                       2, Eq(ServerToClientMessageType_MaskedInputRequest),
649                       Eq(server_messages[2].ByteSizeLong())),
650                   IsEvent<IndividualMessageSent>(
651                       3, Eq(ServerToClientMessageType_MaskedInputRequest),
652                       Eq(server_messages[3].ByteSizeLong()))));
653 }
654 
TEST(SecaggServerR1ShareKeysStateTest,ServerAndClientAbortsAreRecordedCorrectly)655 TEST(SecaggServerR1ShareKeysStateTest,
656      ServerAndClientAbortsAreRecordedCorrectly) {
657   // In this test clients abort for a variety of reasons, and then ultimately
658   // the server aborts. Metrics should record all of these events.
659   MockSecAggServerMetricsListener* metrics =
660       new MockSecAggServerMetricsListener();
661   auto sender = std::make_shared<MockSendToClientsInterface>();
662 
663   SecAggServerR1ShareKeysState state(
664       CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics),
665       0,  // number_of_clients_failed_after_sending_masked_input
666       0,  // number_of_clients_failed_before_sending_masked_input
667       0   // number_of_clients_terminated_without_unmasking
668   );
669 
670   EXPECT_CALL(
671       *metrics,
672       ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
673                      Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
674   EXPECT_CALL(
675       *metrics,
676       ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
677                      Eq(ClientDropReason::SHARE_KEYS_UNEXPECTED)));
678   EXPECT_CALL(
679       *metrics,
680       ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
681                      Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
682   EXPECT_CALL(
683       *metrics,
684       ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
685                      Eq(ClientDropReason::INVALID_SHARE_KEYS_RESPONSE)))
686       .Times(3);
687   EXPECT_CALL(
688       *metrics,
689       ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
690 
691   ClientToServerWrapperMessage abort_message;
692   abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
693   ClientToServerWrapperMessage valid_message;  // from client 1
694   for (int j = 0; j < 7; ++j) {
695     if (1 == j) {
696       valid_message.mutable_share_keys_response()->add_encrypted_key_shares("");
697     } else {
698       valid_message.mutable_share_keys_response()->add_encrypted_key_shares(
699           absl::StrCat("encrypted key shares from ", 1, " to ", j));
700     }
701   }
702 
703   ClientToServerWrapperMessage invalid_message_wrong_number;  // from client 2
704   for (int j = 0; j <= 7; ++j) {  // goes one past the end
705     if (2 == j) {
706       invalid_message_wrong_number.mutable_share_keys_response()
707           ->add_encrypted_key_shares("");
708     } else {
709       invalid_message_wrong_number.mutable_share_keys_response()
710           ->add_encrypted_key_shares(
711               absl::StrCat("encrypted key shares from ", 2, " to ", j));
712     }
713   }
714 
715   ClientToServerWrapperMessage invalid_message_missing_share;  // from client 3
716   for (int j = 0; j < 7; ++j) {
717     if (3 == j || 0 == j) {  // missing share for 0
718       invalid_message_missing_share.mutable_share_keys_response()
719           ->add_encrypted_key_shares("");
720     } else {
721       invalid_message_missing_share.mutable_share_keys_response()
722           ->add_encrypted_key_shares(
723               absl::StrCat("encrypted key shares from ", 3, " to ", j));
724     }
725   }
726 
727   ClientToServerWrapperMessage invalid_message_extra_share;  // from client 4
728   for (int j = 0; j < 7; ++j) {
729     // including share for self, which is wrong
730     invalid_message_extra_share.mutable_share_keys_response()
731         ->add_encrypted_key_shares(
732             absl::StrCat("encrypted key shares from ", 4, " to ", j));
733   }
734 
735   ClientToServerWrapperMessage wrong_message;
736   wrong_message.mutable_advertise_keys();  // wrong type of message
737 
738   state.HandleMessage(0, abort_message).IgnoreError();
739   state.HandleMessage(1, valid_message).IgnoreError();
740   state.HandleMessage(1, valid_message).IgnoreError();
741   state.HandleMessage(2, invalid_message_wrong_number).IgnoreError();
742   state.HandleMessage(3, invalid_message_missing_share).IgnoreError();
743   state.HandleMessage(4, invalid_message_extra_share).IgnoreError();
744   state.HandleMessage(5, wrong_message).IgnoreError();
745   state.ProceedToNextRound().IgnoreError();  // causes server abort
746 }
747 
TEST(SecaggServerR1ShareKeysStateTest,MetricsAreRecorded)748 TEST(SecaggServerR1ShareKeysStateTest, MetricsAreRecorded) {
749   // In this test, all clients send inputs for the correct clients, and then the
750   // server proceeds to the next state. (The inputs aren't actually encrypted
751   // shared keys, but that doesn't matter for this test.)
752   MockSecAggServerMetricsListener* metrics =
753       new MockSecAggServerMetricsListener();
754   auto sender = std::make_shared<MockSendToClientsInterface>();
755 
756   SecAggServerR1ShareKeysState state(
757       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
758       0,  // number_of_clients_failed_after_sending_masked_input
759       0,  // number_of_clients_failed_before_sending_masked_input
760       0   // number_of_clients_terminated_without_unmasking
761   );
762 
763   EXPECT_CALL(*metrics, ClientResponseTimes(
764                             Eq(ClientToServerWrapperMessage::
765                                    MessageContentCase::kShareKeysResponse),
766                             Ge(0)))
767       .Times(4);
768 
769   for (int i = 0; i < 5; ++i) {
770     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
771     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
772     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
773     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
774     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
775     if (i < 3) {
776       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
777       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
778     } else {
779       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
780       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
781     }
782     if (i < 4) {
783       // Have one client send the right vector of "encrypted keys" to the
784       // server.
785       ClientToServerWrapperMessage client_message;
786       for (int j = 0; j < 4; ++j) {
787         if (i == j) {
788           client_message.mutable_share_keys_response()
789               ->add_encrypted_key_shares("");
790         } else {
791           client_message.mutable_share_keys_response()
792               ->add_encrypted_key_shares(
793                   absl::StrCat("encrypted key shares from ", i, " to ", j));
794         }
795       }
796       ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
797       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
798     }
799   }
800   std::vector<ServerToClientWrapperMessage> server_messages(4);
801   for (int i = 0; i < 4; ++i) {
802     for (int j = 0; j < 4; ++j) {
803       if (i == j) {
804         server_messages[i]
805             .mutable_masked_input_request()
806             ->add_encrypted_key_shares("");
807       } else {
808         server_messages[i]
809             .mutable_masked_input_request()
810             ->add_encrypted_key_shares(
811                 absl::StrCat("encrypted key shares from ", j, " to ", i));
812       }
813     }
814     EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
815   }
816   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
817   EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R1_SHARE_KEYS),
818                                    Eq(true), Ge(0)));
819   EXPECT_CALL(*metrics, RoundSurvivingClients(
820                             Eq(SecAggServerStateKind::R1_SHARE_KEYS), Eq(4)));
821 
822   auto next_state = state.ProceedToNextRound();
823   ASSERT_THAT(next_state, IsOk());
824   EXPECT_THAT(next_state.value()->State(),
825               Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
826 }
827 }  // namespace
828 }  // namespace secagg
829 }  // namespace fcp
830