1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server.h"
18
19 #include <cstddef>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "gmock/gmock.h"
27 #include "gtest/gtest.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/secagg/server/secagg_server_enums.pb.h"
30 #include "fcp/secagg/server/secagg_server_state.h"
31 #include "fcp/secagg/server/tracing_schema.h"
32 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
33 #include "fcp/secagg/shared/input_vector_specification.h"
34 #include "fcp/secagg/shared/secagg_messages.pb.h"
35 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
36 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
37 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
38 #include "fcp/secagg/testing/server/test_secagg_experiments.h"
39 #include "fcp/testing/testing.h"
40 #include "fcp/tracing/test_tracing_recorder.h"
41
42 namespace fcp {
43 namespace secagg {
44 namespace {
45
46 using ::testing::_;
47 using ::testing::Eq;
48
CreateServer(SendToClientsInterface * sender,SecAggServerMetricsListener * metrics=new MockSecAggServerMetricsListener (),std::unique_ptr<TestSecAggExperiment> experiments=std::make_unique<TestSecAggExperiment> ())49 std::unique_ptr<SecAggServer> CreateServer(
50 SendToClientsInterface* sender,
51 SecAggServerMetricsListener* metrics =
52 new MockSecAggServerMetricsListener(),
53 std::unique_ptr<TestSecAggExperiment> experiments =
54 std::make_unique<TestSecAggExperiment>()) {
55 SecureAggregationRequirements threat_model;
56 threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
57 threat_model.set_adversarial_client_rate(.3);
58 threat_model.set_estimated_dropout_rate(.3);
59 std::unique_ptr<AesPrngFactory> prng_factory;
60 std::vector<InputVectorSpecification> input_vector_specs;
61 input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
62 auto status_or_server = SecAggServer::Create(
63 100, // minimum_number_of_clients_to_proceed
64 1000, // total_number_of_clients
65 input_vector_specs, sender,
66 std::unique_ptr<SecAggServerMetricsListener>(metrics),
67 /*prng_runner=*/nullptr, std::move(experiments), threat_model);
68 EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status();
69 return std::move(status_or_server.value());
70 }
71
72 template <typename... M>
TraceRecorderHas(const M &...matchers)73 auto TraceRecorderHas(const M&... matchers) {
74 return ElementsAre(AllOf(
75 IsSpan<CreateSecAggServer>(),
76 ElementsAre(
77 IsEvent<SubGraphServerParameters>(
78 1000, // number_of_clients
79 219, // degree
80 116, // threshold
81 700, // minimum_number_of_clients_to_proceed
82 false), // is_r2_async_aggregation_enabled
83 AllOf(IsSpan<SecureAggServerSession>(), ElementsAre(matchers...)))));
84 }
85
TEST(SecaggServerTest, ConstructedWithCorrectState) {
  // Created before the server so that the construction spans are recorded.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  // A freshly built server sits in round 0 and is neither aborted nor done.
  EXPECT_FALSE(server->IsAborted());
  EXPECT_EQ(server->NumberOfNeighbors(), 219);
  EXPECT_FALSE(server->IsCompletedSuccessfully());
  EXPECT_EQ(server->State(), SecAggServerStateKind::R0_ADVERTISE_KEYS);

  // The session so far consists only of the R0 state span.
  auto only_r0_state =
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys);
  EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(only_r0_state));
}
99
TEST(SecaggServerTest, FullgraphSecAggExperimentTakesEffect) {
  // Tests FullgraphSecAggExperiment by instantiating a server under that
  // experiment and checking that it results in the expected number of
  // neighbors for the given setting (1000 clients) and threat model (.3
  // dropout rate and .3 adversarial client rate). With the full graph, the
  // number of neighbors equals the total number of clients.
  SecureAggregationRequirements threat_model;
  threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
  threat_model.set_adversarial_client_rate(.3);
  threat_model.set_estimated_dropout_rate(.3);
  std::vector<InputVectorSpecification> input_vector_specs;
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::set<std::string> experiment_names = {kFullgraphSecAggExperiment};
  auto status_or_server = SecAggServer::Create(
      100,   // minimum_number_of_clients_to_proceed
      1000,  // total_number_of_clients
      input_vector_specs, sender.get(),
      std::unique_ptr<SecAggServerMetricsListener>(
          new MockSecAggServerMetricsListener()),
      /*prng_runner=*/nullptr,
      std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
  // Fatal: every check below dereferences the created server, which must not
  // be attempted on an error StatusOr.
  ASSERT_THAT(status_or_server.ok(), Eq(true)) << status_or_server.status();
  const auto& server = status_or_server.value();
  EXPECT_THAT(server->NumberOfNeighbors(), Eq(1000));
  EXPECT_THAT(server->IsAborted(), Eq(false));
  EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
}
130
TEST(SecaggServerTest, SubgraphSecAggResortsToFullGraphOnSmallCohorts) {
  // Tests that a small number of clients, for which subgraph-secagg does not
  // have favorable parameters, results in executing the full-graph variant:
  // with no experiments enabled, every one of the 25 clients is still a
  // neighbor.
  SecureAggregationRequirements threat_model;
  threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
  threat_model.set_adversarial_client_rate(.45);
  threat_model.set_estimated_dropout_rate(.45);
  std::vector<InputVectorSpecification> input_vector_specs;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
  std::set<std::string> experiment_names = {};
  auto status_or_server = SecAggServer::Create(
      5,   // minimum_number_of_clients_to_proceed
      25,  // total_number_of_clients
      input_vector_specs, sender.get(),
      std::unique_ptr<SecAggServerMetricsListener>(
          new MockSecAggServerMetricsListener()),
      /*prng_runner=*/nullptr,
      std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
  // Fatal: every check below dereferences the created server, which must not
  // be attempted on an error StatusOr.
  ASSERT_THAT(status_or_server.ok(), Eq(true)) << status_or_server.status();
  const auto& server = status_or_server.value();
  EXPECT_THAT(server->NumberOfNeighbors(), Eq(25));
  EXPECT_THAT(server->MinimumSurvivingNeighborsForReconstruction(), Eq(14));
  EXPECT_THAT(server->IsAborted(), Eq(false));
  EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
}
161
// Despite the test name, nothing is thrown: AbortClient reports an
// out-of-range client id through its returned Status.
TEST(SecaggServerTest, AbortClientWithInvalidIdThrowsError) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  // CreateServer() configures 1000 clients, so id 1001 is not a valid client.
  Status abort_status =
      server->AbortClient(1001, ClientAbortReason::CONNECTION_DROPPED);
  EXPECT_THAT(abort_status.code(), Eq(FAILED_PRECONDITION));
}
171
// Despite the test name, nothing is thrown: ReceiveMessage reports an
// out-of-range client id through the status of its returned StatusOr.
TEST(SecaggServerTest, ReceiveMessageWithInvalidIdThrowsError) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  ClientToServerWrapperMessage abort_from_client;
  abort_from_client.mutable_abort()->set_diagnostic_info("Abort for test.");
  auto owned_message =
      std::make_unique<ClientToServerWrapperMessage>(abort_from_client);

  // CreateServer() configures 1000 clients, so id 1001 is not a valid client.
  StatusOr<bool> receive_result =
      server->ReceiveMessage(1001, std::move(owned_message));
  EXPECT_THAT(receive_result.status().code(), Eq(FAILED_PRECONDITION));
}
187
TEST(SecaggServerTest, AbortCausesStateTransitionAndMessageToBeSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  // The abort must be broadcast to all clients with this exact payload.
  const ServerToClientWrapperMessage expected_broadcast = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "Abort upon external request."
    })pb");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));

  Status abort_status = server->Abort();

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->IsAborted());
  EXPECT_EQ(server->State(), SecAggServerStateKind::ABORTED);
  ASSERT_TRUE(server->ErrorMessage().ok());
  EXPECT_THAT(server->ErrorMessage().value(),
              Eq("Abort upon external request."));

  // The abort is recorded inside the R0 state span, followed by the span of
  // the aborted state.
  auto r0_with_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggServer>("Abort upon external request.")));
  auto aborted_state =
      IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted);
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(r0_with_abort, aborted_state));
}
217
TEST(SecaggServerTest, AbortWithReasonCausesStateTransitionAndMessageToBeSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());

  // The supplied reason is embedded in the broadcast diagnostic text.
  const ServerToClientWrapperMessage expected_broadcast = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "Abort upon external request for reason <Test reason.>."
    })pb");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));

  Status abort_status =
      server->Abort("Test reason.", SecAggServerOutcome::EXTERNAL_REQUEST);

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->IsAborted());
  EXPECT_EQ(server->State(), SecAggServerStateKind::ABORTED);
  ASSERT_TRUE(server->ErrorMessage().ok());
  EXPECT_THAT(server->ErrorMessage().value(),
              Eq("Abort upon external request for reason <Test reason.>."));

  // The abort is recorded inside the R0 state span, followed by the span of
  // the aborted state.
  auto r0_with_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggServer>(
          "Abort upon external request for reason <Test "
          "reason.>.")));
  auto aborted_state =
      IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted);
  EXPECT_THAT(tracing_recorder.root(),
              TraceRecorderHas(r0_with_abort, aborted_state));
}
249
TEST(SecaggServerTest, AbortClientNotCheckedIn) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server; the raw pointer stays usable for setting
  // expectations while the server is alive.
  auto* metrics = new MockSecAggServerMetricsListener();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get(), metrics);

  // Aborting a client that never checked in must not be reported as a
  // protocol-initiated drop of a silent client.
  EXPECT_CALL(*metrics, ClientsDropped(
                            Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                            Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)))
      .Times(0);
  // Client is not notified
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  Status abort_status =
      server->AbortClient(2, ClientAbortReason::NOT_CHECKED_IN);

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->AbortedClientIds().contains(2));
  auto r0_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(2, "NOT_CHECKED_IN")));
  EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(r0_with_client_abort));
}
273
TEST(SecaggServerTest, AbortClientWhenConnectionDropped) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server; the raw pointer stays usable for setting
  // expectations while the server is alive.
  auto* metrics = new MockSecAggServerMetricsListener();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get(), metrics);

  // A dropped connection is reported to metrics as CONNECTION_CLOSED.
  EXPECT_CALL(*metrics,
              ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                             Eq(ClientDropReason::CONNECTION_CLOSED)));
  // Client is not notified
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  Status abort_status =
      server->AbortClient(2, ClientAbortReason::CONNECTION_DROPPED);

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->AbortedClientIds().contains(2));
  auto r0_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(2, "CONNECTION_DROPPED")));
  EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(r0_with_client_abort));
}
296
TEST(SecaggServerTest, AbortClientWhenInvalidMessageSent) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership passes to the server; the raw pointer stays usable for setting
  // expectations while the server is alive.
  auto* metrics = new MockSecAggServerMetricsListener();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get(), metrics);

  // Unlike a dropped connection, the offending client is told why it is being
  // closed.
  const ServerToClientWrapperMessage expected_to_client = PARSE_TEXT_PROTO(R"pb(
    abort: {
      early_success: false
      diagnostic_info: "The protocol is closing client with ClientAbortReason <INVALID_MESSAGE>."
    })pb");
  EXPECT_CALL(*sender, Send(2, EqualsProto(expected_to_client)));

  EXPECT_CALL(
      *metrics,
      ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                     Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)));

  Status abort_status =
      server->AbortClient(2, ClientAbortReason::INVALID_MESSAGE);

  EXPECT_THAT(abort_status, IsOk());
  EXPECT_TRUE(server->AbortedClientIds().contains(2));
  auto r0_with_client_abort = AllOf(
      IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
      ElementsAre(IsSpan<AbortSecAggClient>(2, "INVALID_MESSAGE")));
  EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(r0_with_client_abort));
}
325
TEST(SecaggServerTest, ReceiveMessageCausesServerToAbortIfTooManyClientsAbort) {
  // The actual behavior of the server upon receipt of messages is tested in the
  // state class test files, but this tests the special behavior that the server
  // should automatically transition to an abort state if it cannot continue.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto server = CreateServer(sender.get());
  StatusOr<int> clients_needed = server->MinimumMessagesNeededForNextRound();
  ASSERT_THAT(clients_needed.ok(), Eq(true));
  // How many clients can abort while the remaining alive clients still meet
  // the minimum needed for the next round.
  int maximum_number_of_aborts =
      server->NumberOfAliveClients() - clients_needed.value();
  ClientToServerWrapperMessage client_abort_message;
  client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test.");

  // Receiving `maximum_number_of_aborts` aborts (from clients 0 through
  // `maximum_number_of_aborts - 1`) should not cause the entire protocol to
  // abort. After each one, the trace must show one more ReceiveSecAggMessage
  // span inside the R0 state span.
  std::vector<Matcher<const TestTracingRecorder::SpanOrEvent&>> matchers;
  for (int i = 0; i < maximum_number_of_aborts; ++i) {
    StatusOr<bool> result = server->ReceiveMessage(
        i,
        std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
    matchers.push_back(IsSpan<ReceiveSecAggMessage>(i));
    ASSERT_THAT(result.ok(), Eq(true));
    EXPECT_THAT(result.value(), Eq(false));
    EXPECT_THAT(server->IsAborted(), Eq(false));
    EXPECT_THAT(
        tracing_recorder.root(),
        TraceRecorderHas(AllOf(IsSpan<SecureAggServerState>(
                                   SecAggServerTraceState_R0AdvertiseKeys),
                               ElementsAreArray(matchers))));
  }
  // One more abort (the `maximum_number_of_aborts + 1`-th, from client id
  // `maximum_number_of_aborts`) leaves too few clients to continue, so the
  // protocol is ready to proceed to the aborted state. This is indicated by
  // ReceiveMessage returning true.
  StatusOr<bool> result = server->ReceiveMessage(
      maximum_number_of_aborts,
      std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
  matchers.push_back(IsSpan<ReceiveSecAggMessage>(maximum_number_of_aborts));
  ASSERT_THAT(result.ok(), Eq(true));
  EXPECT_THAT(result.value(), Eq(true));
  // However the server is not aborted until ProceedToNextRound is called.
  EXPECT_THAT(server->IsAborted(), Eq(false));

  EXPECT_THAT(server->ProceedToNextRound(), IsOk());
  matchers.push_back(IsSpan<ProceedToNextSecAggRound>());
  EXPECT_THAT(server->IsAborted(), Eq(true));
  EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED));

  EXPECT_THAT(
      tracing_recorder.root(),
      TraceRecorderHas(
          AllOf(IsSpan<SecureAggServerState>(
                    SecAggServerTraceState_R0AdvertiseKeys),
                ElementsAreArray(matchers)),
          IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted)));
}
383
TEST(SecaggServerTest, VerifyErrorsInAbortedState) {
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  std::unique_ptr<SecAggServer> server = CreateServer(sender.get());
  EXPECT_THAT(server->Abort(), IsOk());

  // Once aborted, every round-progressing or round-querying call is rejected
  // with FAILED_PRECONDITION.
  auto empty_message = std::make_unique<ClientToServerWrapperMessage>(
      ClientToServerWrapperMessage{});
  EXPECT_THAT(server->ReceiveMessage(1, std::move(empty_message)),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->ProceedToNextRound(), IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->MinimumMessagesNeededForNextRound(),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->NumberOfMessagesReceivedInThisRound(),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(server->ReadyForNextRound(), IsCode(FAILED_PRECONDITION));
}
401
402 } // namespace
403 } // namespace secagg
404 } // namespace fcp
405