xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server.h"
18 
19 #include <cstddef>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "gmock/gmock.h"
27 #include "gtest/gtest.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/secagg/server/secagg_server_enums.pb.h"
30 #include "fcp/secagg/server/secagg_server_state.h"
31 #include "fcp/secagg/server/tracing_schema.h"
32 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
33 #include "fcp/secagg/shared/input_vector_specification.h"
34 #include "fcp/secagg/shared/secagg_messages.pb.h"
35 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
36 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
37 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
38 #include "fcp/secagg/testing/server/test_secagg_experiments.h"
39 #include "fcp/testing/testing.h"
40 #include "fcp/tracing/test_tracing_recorder.h"
41 
42 namespace fcp {
43 namespace secagg {
44 namespace {
45 
46 using ::testing::_;
47 using ::testing::Eq;
48 
CreateServer(SendToClientsInterface * sender,SecAggServerMetricsListener * metrics=new MockSecAggServerMetricsListener (),std::unique_ptr<TestSecAggExperiment> experiments=std::make_unique<TestSecAggExperiment> ())49 std::unique_ptr<SecAggServer> CreateServer(
50     SendToClientsInterface* sender,
51     SecAggServerMetricsListener* metrics =
52         new MockSecAggServerMetricsListener(),
53     std::unique_ptr<TestSecAggExperiment> experiments =
54         std::make_unique<TestSecAggExperiment>()) {
55   SecureAggregationRequirements threat_model;
56   threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
57   threat_model.set_adversarial_client_rate(.3);
58   threat_model.set_estimated_dropout_rate(.3);
59   std::unique_ptr<AesPrngFactory> prng_factory;
60   std::vector<InputVectorSpecification> input_vector_specs;
61   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
62   auto status_or_server = SecAggServer::Create(
63       100,   // minimum_number_of_clients_to_proceed
64       1000,  // total_number_of_clients
65       input_vector_specs, sender,
66       std::unique_ptr<SecAggServerMetricsListener>(metrics),
67       /*prng_runner=*/nullptr, std::move(experiments), threat_model);
68   EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status();
69   return std::move(status_or_server.value());
70 }
71 
72 template <typename... M>
TraceRecorderHas(const M &...matchers)73 auto TraceRecorderHas(const M&... matchers) {
74   return ElementsAre(AllOf(
75       IsSpan<CreateSecAggServer>(),
76       ElementsAre(
77           IsEvent<SubGraphServerParameters>(
78               1000,    // number_of_clients
79               219,     // degree
80               116,     // threshold
81               700,     // minimum_number_of_clients_to_proceed
82               false),  // is_r2_async_aggregation_enabled
83           AllOf(IsSpan<SecureAggServerSession>(), ElementsAre(matchers...)))));
84 }
85 
// A freshly created server (default CreateServer() configuration) should:
//   * be in R0_ADVERTISE_KEYS;
//   * be neither aborted nor completed;
//   * report 219 neighbors;
//   * have only the initial R0AdvertiseKeys state span inside its session
//     trace.
TEST(SecaggServerTest, ConstructedWithCorrectState) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  EXPECT_FALSE(server->IsAborted());
  EXPECT_THAT(server->NumberOfNeighbors(), Eq(219));
  EXPECT_FALSE(server->IsCompletedSuccessfully());
  EXPECT_EQ(server->State(), SecAggServerStateKind::R0_ADVERTISE_KEYS);

  const auto initial_state_span =
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys);
  EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(initial_state_span));
}
99 
TEST(SecaggServerTest, FullgraphSecAggExperimentTakesEffect) {
  // Tests FullgraphSecAggExperiment by instantiating a server under that
  // experiment and checking that it results in the expected number of
  // neighbors (every one of the 1000 clients) for the given setting and
  // threat model (.3 dropout rate and .3 adversarial client rate).
  SecureAggregationRequirements threat_model;
  threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
  threat_model.set_adversarial_client_rate(.3);
  threat_model.set_estimated_dropout_rate(.3);
  std::vector<InputVectorSpecification> input_vector_specs;
  input_vector_specs.emplace_back("foobar", 4, 32);
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::set<std::string> experiment_names = {kFullgraphSecAggExperiment};
  auto status_or_server = SecAggServer::Create(
      100,   // minimum_number_of_clients_to_proceed
      1000,  // total_number_of_clients
      input_vector_specs, sender.get(),
      std::unique_ptr<SecAggServerMetricsListener>(
          new MockSecAggServerMetricsListener()),
      /*prng_runner=*/nullptr,
      std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
  // Fatal: every check below dereferences value(), which must not be reached
  // when creation failed.
  ASSERT_TRUE(status_or_server.ok()) << status_or_server.status();
  const auto& server = status_or_server.value();

  EXPECT_THAT(server->NumberOfNeighbors(), Eq(1000));
  EXPECT_THAT(server->IsAborted(), Eq(false));
  EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
}
130 
TEST(SecaggServerTest, SubgraphSecAggResortsToFullGraphOnSmallCohorts) {
  // Tests that a small number of clients, for which subgraph-secagg does not
  // have favorable parameters, results in executing the full-graph variant:
  // every one of the 25 clients is a neighbor.
  SecureAggregationRequirements threat_model;
  threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
  threat_model.set_adversarial_client_rate(.45);
  threat_model.set_estimated_dropout_rate(.45);
  std::vector<InputVectorSpecification> input_vector_specs;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  input_vector_specs.emplace_back("foobar", 4, 32);
  std::set<std::string> experiment_names = {};
  auto status_or_server = SecAggServer::Create(
      5,   // minimum_number_of_clients_to_proceed
      25,  // total_number_of_clients
      input_vector_specs, sender.get(),
      std::unique_ptr<SecAggServerMetricsListener>(
          new MockSecAggServerMetricsListener()),
      /*prng_runner=*/nullptr,
      std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
  // Fatal: every check below dereferences value(), which must not be reached
  // when creation failed.
  ASSERT_TRUE(status_or_server.ok()) << status_or_server.status();
  const auto& server = status_or_server.value();

  EXPECT_THAT(server->NumberOfNeighbors(), Eq(25));
  EXPECT_THAT(server->MinimumSurvivingNeighborsForReconstruction(), Eq(14));
  EXPECT_THAT(server->IsAborted(), Eq(false));
  EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
}
161 
// AbortClient() on a client id the server does not know (1001, beyond the
// 1000 clients configured by CreateServer()) must return a
// FAILED_PRECONDITION status. Despite the test name, no exception is
// involved; the error is reported through the returned Status.
TEST(SecaggServerTest, AbortClientWithInvalidIdThrowsError) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());

  constexpr int kUnknownClientId = 1001;
  const Status abort_status = server->AbortClient(
      kUnknownClientId, ClientAbortReason::CONNECTION_DROPPED);
  EXPECT_THAT(abort_status.code(), Eq(FAILED_PRECONDITION));
}
171 
// ReceiveMessage() from a client id the server does not know (1001, beyond
// the 1000 clients configured by CreateServer()) must fail with
// FAILED_PRECONDITION. The error is reported through the returned status.
TEST(SecaggServerTest, ReceiveMessageWithInvalidIdThrowsError) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());

  auto abort_from_client = std::make_unique<ClientToServerWrapperMessage>();
  abort_from_client->mutable_abort()->set_diagnostic_info("Abort for test.");

  constexpr int kUnknownClientId = 1001;
  const StatusOr<bool> outcome =
      server->ReceiveMessage(kUnknownClientId, std::move(abort_from_client));
  EXPECT_THAT(outcome.status().code(), Eq(FAILED_PRECONDITION));
}
187 
// An external Abort() must:
//   * send a SendBroadcast abort message with the default diagnostic;
//   * move the server into ABORTED, exposing that diagnostic as its error
//     message;
//   * trace an AbortSecAggServer span inside the R0AdvertiseKeys state span,
//     followed by an Aborted state span.
TEST(SecaggServerTest, AbortCausesStateTransitionAndMessageToBeSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());
  constexpr char kExpectedDiagnostic[] = "Abort upon external request.";

  const ServerToClientWrapperMessage expected_broadcast = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "Abort upon external request."
    })pb");
  // The expectation has to be registered before Abort() triggers the send.
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));

  const Status abort_status = server->Abort();

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->IsAborted());
  EXPECT_EQ(server->State(), SecAggServerStateKind::ABORTED);
  const auto error_message = server->ErrorMessage();
  ASSERT_TRUE(error_message.ok());
  EXPECT_THAT(error_message.value(), Eq(kExpectedDiagnostic));

  const auto round_zero_with_abort =
      AllOf(IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
            ElementsAre(IsSpan<AbortSecAggServer>(kExpectedDiagnostic)));
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(round_zero_with_abort,
                               IsSpan<SecureAggServerState>(
                                   SecAggServerTraceState_Aborted)));
}
217 
// Abort(reason, EXTERNAL_REQUEST) behaves like the plain Abort(), except that
// the reason is embedded in the diagnostic. That diagnostic appears in the
// broadcast abort message, the server's error message and the
// AbortSecAggServer trace span.
TEST(SecaggServerTest, AbortWithReasonCausesStateTransitionAndMessageToBeSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());
  constexpr char kExpectedDiagnostic[] =
      "Abort upon external request for reason <Test reason.>.";

  const ServerToClientWrapperMessage expected_broadcast = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "Abort upon external request for reason <Test reason.>."
    })pb");
  // The expectation has to be registered before Abort() triggers the send.
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));

  const Status abort_status =
      server->Abort("Test reason.", SecAggServerOutcome::EXTERNAL_REQUEST);

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->IsAborted());
  EXPECT_EQ(server->State(), SecAggServerStateKind::ABORTED);
  const auto error_message = server->ErrorMessage();
  ASSERT_TRUE(error_message.ok());
  EXPECT_THAT(error_message.value(), Eq(kExpectedDiagnostic));

  const auto round_zero_with_abort =
      AllOf(IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
            ElementsAre(IsSpan<AbortSecAggServer>(kExpectedDiagnostic)));
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(round_zero_with_abort,
                               IsSpan<SecureAggServerState>(
                                   SecAggServerTraceState_Aborted)));
}
249 
// Aborting a client with reason NOT_CHECKED_IN must:
//   * not report a SERVER_PROTOCOL_ABORT_CLIENT drop to the metrics listener;
//   * not send anything to the client;
//   * still record the client as aborted and trace an AbortSecAggClient span
//     inside round 0.
TEST(SecaggServerTest, AbortClientNotCheckedIn) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server inside CreateServer().
  auto* metrics = new MockSecAggServerMetricsListener();
  auto server = CreateServer(sender.get(), metrics);
  constexpr int kClientId = 2;

  // Expectations must be set before AbortClient() runs.
  EXPECT_CALL(*metrics, ClientsDropped(
                            Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                            Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)))
      .Times(0);
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  const Status abort_status =
      server->AbortClient(kClientId, ClientAbortReason::NOT_CHECKED_IN);

  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_TRUE(server->AbortedClientIds().contains(kClientId));
  const auto round_zero_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(kClientId, "NOT_CHECKED_IN")));
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(round_zero_with_client_abort));
}
273 
// Aborting a client whose connection dropped must:
//   * report a CONNECTION_CLOSED drop (DEAD_BEFORE_SENDING_ANYTHING) to the
//     metrics listener;
//   * not try to send anything to that client;
//   * record the client as aborted and trace an AbortSecAggClient span inside
//     round 0.
TEST(SecaggServerTest, AbortClientWhenConnectionDropped) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server inside CreateServer().
  auto* metrics = new MockSecAggServerMetricsListener();
  auto server = CreateServer(sender.get(), metrics);
  constexpr int kClientId = 2;

  // Expectations must be set before AbortClient() runs.
  EXPECT_CALL(*metrics,
              ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                             Eq(ClientDropReason::CONNECTION_CLOSED)));
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  const Status abort_status =
      server->AbortClient(kClientId, ClientAbortReason::CONNECTION_DROPPED);

  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_TRUE(server->AbortedClientIds().contains(kClientId));
  const auto round_zero_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(kClientId, "CONNECTION_DROPPED")));
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(round_zero_with_client_abort));
}
296 
// Aborting a client for sending an invalid message must:
//   * send that client an abort message naming the INVALID_MESSAGE reason;
//   * report a SERVER_PROTOCOL_ABORT_CLIENT drop to the metrics listener;
//   * record the client as aborted and trace an AbortSecAggClient span inside
//     round 0.
TEST(SecaggServerTest, AbortClientWhenInvalidMessageSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server inside CreateServer().
  auto* metrics = new MockSecAggServerMetricsListener();
  auto server = CreateServer(sender.get(), metrics);
  constexpr int kClientId = 2;

  const ServerToClientWrapperMessage expected_abort = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "The protocol is closing client with ClientAbortReason <INVALID_MESSAGE>."
    })pb");
  // Expectations must be set before AbortClient() runs.
  EXPECT_CALL(*sender, Send(kClientId, EqualsProto(expected_abort)));
  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                     Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)));

  const Status abort_status =
      server->AbortClient(kClientId, ClientAbortReason::INVALID_MESSAGE);

  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_TRUE(server->AbortedClientIds().contains(kClientId));
  const auto round_zero_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(kClientId, "INVALID_MESSAGE")));
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(round_zero_with_client_abort));
}
325 
TEST(SecaggServerTest,ReceiveMessageCausesServerToAbortIfTooManyClientsAbort)326 TEST(SecaggServerTest, ReceiveMessageCausesServerToAbortIfTooManyClientsAbort) {
327   // The actual behavior of the server upon receipt of messages is tested in the
328   // state class test files, but this tests the special behavior that the server
329   // should automatically transition to an abort state if it cannot continue.
330   TestTracingRecorder tracing_recorder;
331   auto sender = std::make_unique<MockSendToClientsInterface>();
332   auto server = CreateServer(sender.get());
333   StatusOr<int> clients_needed = server->MinimumMessagesNeededForNextRound();
334   ASSERT_THAT(clients_needed.ok(), Eq(true));
335   int maximum_number_of_aborts =
336       server->NumberOfAliveClients() - clients_needed.value();
337   EcdhPregeneratedTestKeys ecdh_keys;
338   ClientToServerWrapperMessage client_abort_message;
339   client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test.");
340 
341   // Receiving `maximum_number_of_aborts - 1` aborts should not cause the entire
342   // protocol to abort.
343   std::vector<Matcher<const TestTracingRecorder::SpanOrEvent&>> matchers;
344   for (int i = 0; i < maximum_number_of_aborts; ++i) {
345     StatusOr<bool> result = server->ReceiveMessage(
346         i,
347         std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
348     matchers.push_back(IsSpan<ReceiveSecAggMessage>(i));
349     ASSERT_THAT(result.ok(), Eq(true));
350     EXPECT_THAT(result.value(), Eq(false));
351     EXPECT_THAT(server->IsAborted(), Eq(false));
352     EXPECT_THAT(
353         tracing_recorder.root(),
354         TraceRecorderHas(AllOf(IsSpan<SecureAggServerState>(
355                                    SecAggServerTraceState_R0AdvertiseKeys),
356                                ElementsAreArray(matchers))));
357   }
358   // Receiving `maximum_number_of_aborts` aborts means the protocol is ready to
359   // proceed to the aborted state, which is indicated by ReceiveMessage
360   // returning true.
361   StatusOr<bool> result = server->ReceiveMessage(
362       maximum_number_of_aborts,
363       std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
364   matchers.push_back(IsSpan<ReceiveSecAggMessage>(maximum_number_of_aborts));
365   ASSERT_THAT(result.ok(), Eq(true));
366   EXPECT_THAT(result.value(), Eq(true));
367   // However the server is not aborted until ProceedToNextRound is called.
368   EXPECT_THAT(server->IsAborted(), Eq(false));
369 
370   EXPECT_THAT(server->ProceedToNextRound(), IsOk());
371   matchers.push_back(IsSpan<ProceedToNextSecAggRound>());
372   EXPECT_THAT(server->IsAborted(), Eq(true));
373   EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED));
374 
375   EXPECT_THAT(
376       tracing_recorder.root(),
377       TraceRecorderHas(
378           AllOf(IsSpan<SecureAggServerState>(
379                     SecAggServerTraceState_R0AdvertiseKeys),
380                 ElementsAreArray(matchers)),
381           IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted)));
382 }
383 
// After Abort(), each of the following calls must fail with
// FAILED_PRECONDITION:
//   * receiving a message;
//   * advancing the round;
//   * querying the round's progress (MinimumMessagesNeededForNextRound,
//     NumberOfMessagesReceivedInThisRound, ReadyForNextRound).
TEST(SecaggServerTest, VerifyErrorsInAbortedState) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());
  EXPECT_THAT(server->Abort(), IsOk());

  // An empty message from client 1.
  constexpr int kClientId = 1;
  auto empty_message = std::make_unique<ClientToServerWrapperMessage>();
  EXPECT_THAT(server->ReceiveMessage(kClientId, std::move(empty_message)),
              IsCode(FAILED_PRECONDITION));

  EXPECT_THAT(server->ProceedToNextRound(), IsCode(FAILED_PRECONDITION));

  EXPECT_THAT(server->MinimumMessagesNeededForNextRound(),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->NumberOfMessagesReceivedInThisRound(),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->ReadyForNextRound(), IsCode(FAILED_PRECONDITION));
}
401 
402 }  // namespace
403 }  // namespace secagg
404 }  // namespace fcp
405