1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_r1_share_keys_state.h"
18
19 #include <memory>
20 #include <utility>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/strings/str_cat.h"
25 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
26 #include "fcp/secagg/server/secagg_server_state.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
29 #include "fcp/secagg/shared/compute_session_id.h"
30 #include "fcp/secagg/shared/ecdh_keys.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/shared/shamir_secret_sharing.h"
34 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
35 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
36 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
37 #include "fcp/testing/testing.h"
38 #include "fcp/tracing/test_tracing_recorder.h"
39
40 namespace fcp {
41 namespace secagg {
42 namespace {
43
44 using ::testing::_;
45 using ::testing::Eq;
46 using ::testing::Ge;
47 using ::testing::IsFalse;
48 using ::testing::IsTrue;
49
50 // Default test session_id.
51 SessionId session_id = {"session id number, 32 bytes long"};
52
CreateSecAggServerProtocolImpl(int minimum_number_of_clients_to_proceed,int total_number_of_clients,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)53 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
54 int minimum_number_of_clients_to_proceed, int total_number_of_clients,
55 MockSendToClientsInterface* sender,
56 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
57 auto input_vector_specs = std::vector<InputVectorSpecification>();
58 SecretSharingGraphFactory factory;
59 input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
60 auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
61 factory.CreateCompleteGraph(total_number_of_clients,
62 minimum_number_of_clients_to_proceed),
63 minimum_number_of_clients_to_proceed, input_vector_specs,
64 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
65 std::make_unique<AesCtrPrngFactory>(), sender,
66 std::make_unique<SecAggScheduler>(
67 /*sequential_scheduler=*/nullptr,
68 /*parallel_scheduler=*/nullptr),
69 std::vector<ClientStatus>(total_number_of_clients,
70 ClientStatus::ADVERTISE_KEYS_RECEIVED),
71 ServerVariant::NATIVE_V1);
72 impl->set_session_id(std::make_unique<SessionId>(session_id));
73 EcdhPregeneratedTestKeys ecdh_keys;
74 for (int i = 0; i < total_number_of_clients; i++) {
75 impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
76 }
77 return impl;
78 }
79
// A freshly constructed R1 share-keys state must not report itself aborted.
TEST(SecaggServerR1ShareKeysStateTest, IsAbortedReturnsFalse) {
  auto sender = std::make_shared<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());

  SecAggServerR1ShareKeysState state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool aborted = state.IsAborted();
  EXPECT_THAT(aborted, IsFalse());
}
92
// A freshly constructed R1 share-keys state must not report successful
// completion of the protocol.
TEST(SecaggServerR1ShareKeysStateTest, IsCompletedSuccessfullyReturnsFalse) {
  auto sender = std::make_shared<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());

  SecAggServerR1ShareKeysState state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool completed = state.IsCompletedSuccessfully();
  EXPECT_THAT(completed, IsFalse());
}
105
// Before any abort, asking the R1 state for its error message must yield a
// non-OK status.
TEST(SecaggServerR1ShareKeysStateTest, ErrorMessageRaisesErrorStatus) {
  auto sender = std::make_shared<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());

  SecAggServerR1ShareKeysState state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool has_error_message = state.ErrorMessage().ok();
  EXPECT_THAT(has_error_message, IsFalse());
}
118
// While the protocol is still in R1, Result() must yield a non-OK status.
TEST(SecaggServerR1ShareKeysStateTest, ResultRaisesErrorStatus) {
  auto sender = std::make_shared<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());

  SecAggServerR1ShareKeysState state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool has_result = state.Result().ok();
  EXPECT_THAT(has_result, IsFalse());
}
131
// An externally requested abort must:
//   * broadcast an abort message to clients;
//   * report an EXTERNAL_REQUEST outcome to the metrics listener;
//   * record the broadcast in the trace;
//   * return an ABORTED state carrying the supplied reason.
TEST(SecaggServerR1ShareKeysStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  TestTracingRecorder tracing_recorder;
  // Ownership of `metrics` passes to the protocol impl.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage expected_abort;
  auto* abort_body = expected_abort.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_abort)));
  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));

  auto aborted_state =
      state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(), Eq("test abort reason"));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_abort.ByteSizeLong()))));
}
164
// All four clients send well-formed ShareKeysResponse messages. The state
// must become ready once three responses (the threshold) have arrived. On
// proceeding, it must send each client a MaskedInputRequest holding the
// shares addressed to it, and move to R2_MASKED_INPUT_COLLECTION with no
// failures recorded. The "encrypted" shares are placeholder strings, which
// is enough for this test.
TEST(SecaggServerR1ShareKeysStateTest,
     StateProceedsCorrectlyWithAllClientsValid) {
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // The response client `from` sends: one share per client, empty for itself.
  auto share_keys_response_from = [](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          to == from ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  // Verifies the round bookkeeping after `received` valid responses.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 0; client < 4; ++client) {
    expect_progress(client);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 2));
  }
  expect_progress(4);

  // Client `to` should get every other client's share for it, plus an empty
  // entry at its own index.
  auto masked_input_request_for = [](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      request->add_encrypted_key_shares(
          from == to ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests;
  for (int to = 0; to < 4; ++to) {
    expected_requests.push_back(masked_input_request_for(to));
  }
  for (int to = 0; to < 4; ++to) {
    EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])))
        .Times(1);
  }
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
  EXPECT_THAT(next_state->NumberOfClientsFailedAfterSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
241
// Client 3 already dropped out in round 0. The remaining clients leave its
// share slot empty, the server sends it nothing, and the earlier failure is
// carried over into R2.
TEST(SecaggServerR1ShareKeysStateTest,
     StateProceedsCorrectlyWithOnePreviousDropout) {
  const int dropped_client = 3;
  auto sender = std::make_shared<MockSendToClientsInterface>();
  auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
  impl->set_client_status(dropped_client,
                          ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);

  SecAggServerR1ShareKeysState state(
      std::move(impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/1,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // In a message concerning `owner`, the slot for `other` stays empty when it
  // is the owner's own slot or the dropped client's slot.
  auto is_empty_slot = [dropped_client](int owner, int other) {
    return other == owner || other == dropped_client;
  };

  auto share_keys_response_from = [&is_empty_slot](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          is_empty_slot(from, to)
              ? std::string()
              : absl::StrCat("encrypted key shares from ", from, " to ", to));
    }
    return message;
  };

  // Verifies the round bookkeeping after `received` valid responses.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 0; client < 3; ++client) {
    expect_progress(client);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 2));
  }
  expect_progress(3);

  auto masked_input_request_for = [&is_empty_slot](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      request->add_encrypted_key_shares(
          is_empty_slot(to, from)
              ? std::string()
              : absl::StrCat("encrypted key shares from ", from, " to ", to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests;
  for (int to = 0; to < 3; ++to) {
    expected_requests.push_back(masked_input_request_for(to));
  }
  for (int to = 0; to < 3; ++to) {
    EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])))
        .Times(1);
  }
  EXPECT_CALL(*sender, Send(Eq(dropped_client), _)).Times(0);
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
  EXPECT_THAT(next_state->NumberOfClientsFailedAfterSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(1));
  EXPECT_THAT(next_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
320
// All four clients send valid key shares, then client 2 aborts. Client 2's
// shares must be left out of the MaskedInputRequests sent to the others, no
// request is sent to client 2, and R2 records one failure before masked
// input.
TEST(SecaggServerR1ShareKeysStateTest,
     StateProceedsCorrectlyWithAnAbortAfterSendingShares) {
  const int aborting_client = 2;
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // The response client `from` sends: one share per client, empty for itself.
  auto share_keys_response_from = [](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          to == from ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  // Verifies the round bookkeeping after `received` valid responses.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 0; client < 4; ++client) {
    expect_progress(client);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 2));
  }
  expect_progress(4);

  ClientToServerWrapperMessage client_abort;
  client_abort.mutable_abort()->set_diagnostic_info("aborting for test");
  ASSERT_THAT(state.HandleMessage(aborting_client, client_abort), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());

  // Client `to` gets every surviving client's share for it. Its own slot and
  // the aborted client's slot are empty.
  auto masked_input_request_for = [aborting_client](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      const bool empty_slot = from == to || from == aborting_client;
      request->add_encrypted_key_shares(
          empty_slot ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests(4);
  for (int to = 0; to < 4; ++to) {
    if (to != aborting_client) {
      expected_requests[to] = masked_input_request_for(to);
    }
  }
  for (int to = 0; to < 4; ++to) {
    if (to == aborting_client) {
      EXPECT_CALL(*sender, Send(Eq(to), _)).Times(0);
    } else {
      EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])))
          .Times(1);
    }
  }
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
  EXPECT_THAT(next_state->NumberOfClientsFailedAfterSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(1));
  EXPECT_THAT(next_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
407
// Client 0 leaves out the share it owes client 1. The server must abort
// client 0 individually and continue with the other three. Client 0's shares
// are omitted from the MaskedInputRequests, and R2 records one failure
// before masked input.
TEST(SecaggServerR1ShareKeysStateTest,
     StateProceedsCorrectlyWithOneClientSendingInvalidShares) {
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage expected_abort;
  expected_abort.mutable_abort()->set_early_success(false);
  expected_abort.mutable_abort()->set_diagnostic_info(
      "Client omitted a key share that was expected.");

  // Client `to` gets shares from clients 1-3 other than itself. Its own slot
  // and client 0's slot are empty.
  auto masked_input_request_for = [](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      const bool empty_slot = from == to || from == 0;
      request->add_encrypted_key_shares(
          empty_slot ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests(4);
  for (int to = 1; to < 4; ++to) {
    expected_requests[to] = masked_input_request_for(to);
  }

  EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(expected_abort))).Times(1);
  for (int to = 1; to < 4; ++to) {
    EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])))
        .Times(1);
  }
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);

  ClientToServerWrapperMessage bad_message;
  auto* bad_response = bad_message.mutable_share_keys_response();
  bad_response->add_encrypted_key_shares("");  // Client 0's own slot.
  bad_response->add_encrypted_key_shares("");  // Share for client 1 omitted.
  bad_response->add_encrypted_key_shares("encrypted key shares from 0 to 2");
  bad_response->add_encrypted_key_shares("encrypted key shares from 0 to 3");
  ASSERT_THAT(state.HandleMessage(0, bad_message), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsFalse());

  // The response client `from` sends: one share per client, empty for itself.
  auto share_keys_response_from = [](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          to == from ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  // Verifies bookkeeping after `received` valid responses from clients 1-3.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 1; client < 4; ++client) {
    expect_progress(client - 1);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 3));
  }
  expect_progress(3);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
  EXPECT_THAT(next_state->NumberOfClientsFailedAfterSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(1));
  EXPECT_THAT(next_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
498
// Clients 0 and 1 send abort messages. That leaves two of four clients
// against a threshold of three, so the state must report that it needs to
// abort. Proceeding must broadcast "Too many clients aborted." and move to
// ABORTED.
TEST(SecaggServerR1ShareKeysStateTest, StateAbortsIfTooManyClientsAbort) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get()),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Verifies the state after `aborted` clients have sent abort messages.
  auto expect_state_after_aborts = [&state](int aborted) {
    const bool must_abort = aborted >= 2;
    EXPECT_THAT(state.NeedsToAbort(), Eq(must_abort));
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - aborted));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - aborted));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(must_abort));
  };

  for (int client = 0; client < 2; ++client) {
    expect_state_after_aborts(client);
    ClientToServerWrapperMessage client_abort;
    client_abort.mutable_abort()->set_diagnostic_info("Aborting for test");
    ASSERT_THAT(state.HandleMessage(client, client_abort), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 1));
  }
  expect_state_after_aborts(2);

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("Too many clients aborted.");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)))
      .Times(1);
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(next_state->ErrorMessage(), IsOk());
  EXPECT_THAT(next_state->ErrorMessage().value(),
              Eq("Too many clients aborted."));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
547
// The size of every received ShareKeysResponse must be reported to the
// metrics listener and recorded in the trace. The size of every
// MaskedInputRequest the server sends individually must also be reported and
// traced, and nothing may be broadcast. Otherwise the round proceeds as in
// the all-clients-valid case.
TEST(SecaggServerR1ShareKeysStateTest, MetricsRecordsMessageSizes) {
  TestTracingRecorder tracing_recorder;
  // Ownership of `metrics` passes to the protocol impl.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // The response client `from` sends: one share per client, empty for itself.
  auto share_keys_response_from = [](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          to == from ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  // Verifies the round bookkeeping after `received` valid responses.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 0; client < 4; ++client) {
    expect_progress(client);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    EXPECT_CALL(*metrics,
                MessageReceivedSizes(
                    Eq(ClientToServerWrapperMessage::MessageContentCase::
                           kShareKeysResponse),
                    Eq(true), Eq(message.ByteSizeLong())));
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 2));
    EXPECT_THAT(tracing_recorder.root()[client],
                IsEvent<ClientMessageReceived>(
                    Eq(ClientToServerMessageType_ShareKeysResponse),
                    Eq(message.ByteSizeLong()), Eq(true), Ge(0)));
  }
  expect_progress(4);

  // Client `to` should get every other client's share for it, plus an empty
  // entry at its own index.
  auto masked_input_request_for = [](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      request->add_encrypted_key_shares(
          from == to ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests;
  for (int to = 0; to < 4; ++to) {
    expected_requests.push_back(masked_input_request_for(to));
  }
  for (int to = 0; to < 4; ++to) {
    EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])));
  }
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
  // All four requests carry shares of equal length, so one size covers them.
  EXPECT_CALL(*metrics,
              IndividualMessageSizes(
                  Eq(ServerToClientWrapperMessage::MessageContentCase::
                         kMaskedInputRequest),
                  Eq(expected_requests[0].ByteSizeLong())))
      .Times(4);

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  const auto& next_state = proceed_result.value();
  EXPECT_THAT(next_state->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
  EXPECT_THAT(next_state->NumberOfClientsFailedAfterSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(next_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
  EXPECT_THAT(
      tracing_recorder.FindAllEvents<IndividualMessageSent>(),
      ElementsAre(IsEvent<IndividualMessageSent>(
                      0, Eq(ServerToClientMessageType_MaskedInputRequest),
                      Eq(expected_requests[0].ByteSizeLong())),
                  IsEvent<IndividualMessageSent>(
                      1, Eq(ServerToClientMessageType_MaskedInputRequest),
                      Eq(expected_requests[1].ByteSizeLong())),
                  IsEvent<IndividualMessageSent>(
                      2, Eq(ServerToClientMessageType_MaskedInputRequest),
                      Eq(expected_requests[2].ByteSizeLong())),
                  IsEvent<IndividualMessageSent>(
                      3, Eq(ServerToClientMessageType_MaskedInputRequest),
                      Eq(expected_requests[3].ByteSizeLong()))));
}
654
// Six of seven clients drop out for different reasons:
//   * client 0 aborts;
//   * client 1 sends a valid response twice;
//   * clients 2-4 send malformed responses;
//   * client 5 sends the wrong message type.
// The server then aborts because too few clients remain. Each client drop
// and the final outcome must be reported to the metrics listener.
TEST(SecaggServerR1ShareKeysStateTest,
     ServerAndClientAbortsAreRecordedCorrectly) {
  // Ownership of `metrics` passes to the protocol impl.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
                     Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
                     Eq(ClientDropReason::SHARE_KEYS_UNEXPECTED)));
  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
                     Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
                     Eq(ClientDropReason::INVALID_SHARE_KEYS_RESPONSE)))
      .Times(3);
  EXPECT_CALL(
      *metrics,
      ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));

  // Builds a ShareKeysResponse from `from` with `slot_count` slots. A slot
  // `to` is left empty when `leave_empty(to)` is true.
  auto share_keys_response = [](int from, int slot_count, auto leave_empty) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < slot_count; ++to) {
      response->add_encrypted_key_shares(
          leave_empty(to)
              ? std::string()
              : absl::StrCat("encrypted key shares from ", from, " to ", to));
    }
    return message;
  };

  ClientToServerWrapperMessage abort_message;
  abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");

  // Client 1: correct response.
  ClientToServerWrapperMessage valid_message =
      share_keys_response(1, 7, [](int to) { return to == 1; });
  // Client 2: one slot more than there are clients.
  ClientToServerWrapperMessage invalid_message_wrong_number =
      share_keys_response(2, 8, [](int to) { return to == 2; });
  // Client 3: the share for client 0 is missing.
  ClientToServerWrapperMessage invalid_message_missing_share =
      share_keys_response(3, 7, [](int to) { return to == 3 || to == 0; });
  // Client 4: includes a non-empty share for itself.
  ClientToServerWrapperMessage invalid_message_extra_share =
      share_keys_response(4, 7, [](int) { return false; });
  // Client 5: a message of the wrong type for this round.
  ClientToServerWrapperMessage wrong_message;
  wrong_message.mutable_advertise_keys();

  state.HandleMessage(0, abort_message).IgnoreError();
  state.HandleMessage(1, valid_message).IgnoreError();
  state.HandleMessage(1, valid_message).IgnoreError();  // Duplicate.
  state.HandleMessage(2, invalid_message_wrong_number).IgnoreError();
  state.HandleMessage(3, invalid_message_missing_share).IgnoreError();
  state.HandleMessage(4, invalid_message_extra_share).IgnoreError();
  state.HandleMessage(5, wrong_message).IgnoreError();
  // Too few clients remain, so this triggers the server abort.
  state.ProceedToNextRound().IgnoreError();
}
747
// All four clients respond correctly. The metrics listener must see:
//   * one ClientResponseTimes report per ShareKeysResponse;
//   * on proceeding, a RoundTimes report for R1_SHARE_KEYS;
//   * on proceeding, a RoundSurvivingClients report of 4.
TEST(SecaggServerR1ShareKeysStateTest, MetricsAreRecorded) {
  // Ownership of `metrics` passes to the protocol impl.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_shared<MockSendToClientsInterface>();

  SecAggServerR1ShareKeysState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_CALL(*metrics,
              ClientResponseTimes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::
                         kShareKeysResponse),
                  Ge(0)))
      .Times(4);

  // The response client `from` sends: one share per client, empty for itself.
  auto share_keys_response_from = [](int from) {
    ClientToServerWrapperMessage message;
    auto* response = message.mutable_share_keys_response();
    for (int to = 0; to < 4; ++to) {
      response->add_encrypted_key_shares(
          to == from ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  // Verifies the round bookkeeping after `received` valid responses.
  auto expect_progress = [&state](int received) {
    const bool threshold_met = received >= 3;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(threshold_met ? 0 : 3 - received));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(threshold_met));
  };

  for (int client = 0; client < 4; ++client) {
    expect_progress(client);
    ClientToServerWrapperMessage message = share_keys_response_from(client);
    ASSERT_THAT(state.HandleMessage(client, message), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 2));
  }
  expect_progress(4);

  // Client `to` should get every other client's share for it, plus an empty
  // entry at its own index.
  auto masked_input_request_for = [](int to) {
    ServerToClientWrapperMessage message;
    auto* request = message.mutable_masked_input_request();
    for (int from = 0; from < 4; ++from) {
      request->add_encrypted_key_shares(
          from == to ? std::string()
                     : absl::StrCat("encrypted key shares from ", from, " to ",
                                    to));
    }
    return message;
  };

  std::vector<ServerToClientWrapperMessage> expected_requests;
  for (int to = 0; to < 4; ++to) {
    expected_requests.push_back(masked_input_request_for(to));
  }
  for (int to = 0; to < 4; ++to) {
    EXPECT_CALL(*sender, Send(Eq(to), EqualsProto(expected_requests[to])))
        .Times(1);
  }
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R1_SHARE_KEYS),
                                   Eq(true), Ge(0)));
  EXPECT_CALL(*metrics, RoundSurvivingClients(
                            Eq(SecAggServerStateKind::R1_SHARE_KEYS), Eq(4)));

  auto proceed_result = state.ProceedToNextRound();
  ASSERT_THAT(proceed_result, IsOk());
  EXPECT_THAT(proceed_result.value()->State(),
              Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
}
827 } // namespace
828 } // namespace secagg
829 } // namespace fcp
830