1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/secagg/server/secagg_server.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "absl/strings/str_cat.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
#include "fcp/secagg/server/experiments_names.h"
#include "fcp/secagg/server/graph_parameter_finder.h"
#include "fcp/secagg/server/secagg_scheduler.h"
#include "fcp/secagg/server/secagg_server_aborted_state.h"
#include "fcp/secagg/server/secagg_server_enums.pb.h"
#include "fcp/secagg/server/secagg_server_messages.pb.h"
#include "fcp/secagg/server/secagg_server_metrics_listener.h"
#include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
#include "fcp/secagg/server/secagg_server_state.h"
#include "fcp/secagg/server/secagg_trace_utility.h"
#include "fcp/secagg/server/secret_sharing_graph.h"
#include "fcp/secagg/server/secret_sharing_graph_factory.h"
#include "fcp/secagg/server/send_to_clients_interface.h"
#include "fcp/secagg/server/tracing_schema.h"
#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
#include "fcp/secagg/shared/aes_prng_factory.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/tracing/tracing_span.h"
50
51 namespace fcp {
52 namespace secagg {
53
SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl)54 SecAggServer::SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl) {
55 state_ = std::make_unique<SecAggServerR0AdvertiseKeysState>(std::move(impl));
56
57 // Start the span for the current state. The rest of the state span
58 // transitioning is done in TransitionState.
59 state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
60 span_.Ref(), TracingState(state_->State()));
61 }
62
Create(int minimum_number_of_clients_to_proceed,int total_number_of_clients,const std::vector<InputVectorSpecification> & input_vector_specs,SendToClientsInterface * sender,std::unique_ptr<SecAggServerMetricsListener> metrics,std::unique_ptr<SecAggScheduler> prng_runner,std::unique_ptr<ExperimentsInterface> experiments,const SecureAggregationRequirements & threat_model)63 StatusOr<std::unique_ptr<SecAggServer>> SecAggServer::Create(
64 int minimum_number_of_clients_to_proceed, int total_number_of_clients,
65 const std::vector<InputVectorSpecification>& input_vector_specs,
66 SendToClientsInterface* sender,
67 std::unique_ptr<SecAggServerMetricsListener> metrics,
68 std::unique_ptr<SecAggScheduler> prng_runner,
69 std::unique_ptr<ExperimentsInterface> experiments,
70 const SecureAggregationRequirements& threat_model) {
71 TracingSpan<CreateSecAggServer> span;
72 SecretSharingGraphFactory factory;
73 std::unique_ptr<SecretSharingGraph> secret_sharing_graph;
74 ServerVariant server_variant = ServerVariant::UNKNOWN_VERSION;
75
76 bool is_fullgraph_protocol_variant =
77 experiments->IsEnabled(kFullgraphSecAggExperiment);
78 int degree, threshold;
79 // We first compute parameters degree and threshold for the subgraph variant,
80 // unless the kFullgraphSecAggExperiment is enabled, and then set
81 // is_subgraph_protocol_variant to false if the parameter finding procedure
82 // fails. In that case we resort to classical full-graph secagg.
83 // This will happen for very small values of total_number_of_clients (e.g. <
84 // 65), i.e. cohort sizes where subgraph-secagg does not give much advantage.
85 if (!is_fullgraph_protocol_variant) {
86 if (experiments->IsEnabled(kForceSubgraphSecAggExperiment)) {
87 // In kForceSubgraphSecAggExperiment (which is only for testing
88 // purposes) we fix the degree in the Harary graph to be half the number
89 // of clients (rounding to the next odd number to account for self-edges
90 // as above) and degree to be half of the degree (or 2 whatever is
91 // larger). This means that, for example in a simple test with 5
92 // clients, each client shares keys with 2 other clients and the
93 // threshold is one.
94 degree = total_number_of_clients / 2;
95 if (degree % 2 == 0) {
96 degree += 1;
97 }
98 threshold = std::max(2, degree / 2);
99
100 } else {
101 // kSubgraphSecAggCuriousServerExperiment sets the threat model to
102 // CURIOUS_SERVER in subgraph-secagg executions.
103 // This experiment was introduced as part of go/subgraph-secagg-rollout
104 // and is temporary (see b/191179307).
105 StatusOr<fcp::secagg::HararyGraphParameters>
106 computed_params_status_or_value;
107 if (experiments->IsEnabled(kSubgraphSecAggCuriousServerExperiment)) {
108 SecureAggregationRequirements alternate_threat_model = threat_model;
109 alternate_threat_model.set_adversary_class(
110 AdversaryClass::CURIOUS_SERVER);
111 computed_params_status_or_value = ComputeHararyGraphParameters(
112 total_number_of_clients, alternate_threat_model);
113 } else {
114 computed_params_status_or_value =
115 ComputeHararyGraphParameters(total_number_of_clients, threat_model);
116 }
117 if (computed_params_status_or_value.ok()) {
118 // We add 1 to the computed degree to account for a self-edge in the
119 // SecretSharingHararyGraph graph
120 degree = computed_params_status_or_value->degree + 1;
121 threshold = computed_params_status_or_value->threshold;
122 } else {
123 is_fullgraph_protocol_variant = true;
124 }
125 }
126 }
127
128 // In both the FullGraph and SubGraph variants, the protocol only successfully
129 // completes and returns a sum if no more than
130 // floor(total_number_of_clients * threat_model.estimated_dropout_rate())
131 // clients dropout before the end of the protocol execution. This ensure that
132 // at least ceil(total_number_of_clients *(1. -
133 // threat_model.estimated_dropout_rate() -
134 // threat_model.adversarial_client_rate)) values from honest clients are
135 // included in the final sum.
136 // The protocol allows to make that threshold larger by providing a larger
137 // value of minimum_number_of_clients_to_proceed to the create function, but
138 // never lower.
139 minimum_number_of_clients_to_proceed =
140 std::max(minimum_number_of_clients_to_proceed,
141 static_cast<int>(
142 std::ceil(total_number_of_clients *
143 (1. - threat_model.estimated_dropout_rate()))));
144 if (is_fullgraph_protocol_variant) {
145 // We're instantiating full-graph secagg, either because that was
146 // the intent of the caller (by setting kFullgraphSecAggExperiment), or
147 // because ComputeHararyGraphParameters returned and error.
148 FCP_RETURN_IF_ERROR(CheckFullGraphParameters(
149 total_number_of_clients, minimum_number_of_clients_to_proceed,
150 threat_model));
151 secret_sharing_graph = factory.CreateCompleteGraph(
152 total_number_of_clients, minimum_number_of_clients_to_proceed);
153 server_variant = ServerVariant::NATIVE_V1;
154 Trace<FullGraphServerParameters>(
155 total_number_of_clients, minimum_number_of_clients_to_proceed,
156 experiments->IsEnabled(kSecAggAsyncRound2Experiment));
157 } else {
158 secret_sharing_graph =
159 factory.CreateHararyGraph(total_number_of_clients, degree, threshold);
160 server_variant = ServerVariant::NATIVE_SUBGRAPH;
161 Trace<SubGraphServerParameters>(
162 total_number_of_clients, degree, threshold,
163 minimum_number_of_clients_to_proceed,
164 experiments->IsEnabled(kSecAggAsyncRound2Experiment));
165 }
166
167 return absl::WrapUnique(
168 new SecAggServer(std::make_unique<AesSecAggServerProtocolImpl>(
169 std::move(secret_sharing_graph), minimum_number_of_clients_to_proceed,
170 input_vector_specs, std::move(metrics),
171 std::make_unique<AesCtrPrngFactory>(), sender, std::move(prng_runner),
172 std::vector<ClientStatus>(total_number_of_clients,
173 ClientStatus::READY_TO_START),
174 server_variant, std::move(experiments))));
175 }
176
Abort()177 Status SecAggServer::Abort() {
178 const std::string reason = "Abort upon external request.";
179 TracingSpan<AbortSecAggServer> span(state_span_->Ref(), reason);
180 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
181 TransitionState(state_->Abort(reason, SecAggServerOutcome::EXTERNAL_REQUEST));
182 return FCP_STATUS(OK);
183 }
184
Abort(const std::string & reason,SecAggServerOutcome outcome)185 Status SecAggServer::Abort(const std::string& reason,
186 SecAggServerOutcome outcome) {
187 const std::string formatted_reason =
188 absl::StrCat("Abort upon external request for reason <", reason, ">.");
189 TracingSpan<AbortSecAggServer> span(state_span_->Ref(), formatted_reason);
190 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
191 TransitionState(state_->Abort(formatted_reason, outcome));
192 return FCP_STATUS(OK);
193 }
194
MakeClientAbortMessage(ClientAbortReason reason)195 std::string MakeClientAbortMessage(ClientAbortReason reason) {
196 return absl::StrCat("The protocol is closing client with ClientAbortReason <",
197 ClientAbortReason_Name(reason), ">.");
198 }
199
AbortClient(uint32_t client_id,ClientAbortReason reason)200 Status SecAggServer::AbortClient(uint32_t client_id, ClientAbortReason reason) {
201 TracingSpan<AbortSecAggClient> span(
202 state_span_->Ref(), client_id,
203 ClientAbortReason_descriptor()->FindValueByNumber(reason)->name());
204 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
205 FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
206 // By default, put all AbortClient calls in the same bucket (with some
207 // exceptions below).
208 ClientDropReason client_drop_reason =
209 ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT;
210 bool notify_client = false;
211 bool log_metrics = true;
212 std::string message;
213 // Handle all specific abortClient cases
214 switch (reason) {
215 case ClientAbortReason::INVALID_MESSAGE:
216 notify_client = true;
217 message = MakeClientAbortMessage(reason);
218 break;
219 case ClientAbortReason::CONNECTION_DROPPED:
220 client_drop_reason = ClientDropReason::CONNECTION_CLOSED;
221 break;
222 default:
223 log_metrics = false;
224 message = MakeClientAbortMessage(reason);
225 break;
226 }
227
228 state_->AbortClient(client_id, message, client_drop_reason, notify_client,
229 log_metrics);
230 return FCP_STATUS(OK);
231 }
232
ProceedToNextRound()233 Status SecAggServer::ProceedToNextRound() {
234 TracingSpan<ProceedToNextSecAggRound> span(state_span_->Ref());
235 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
236 StatusOr<std::unique_ptr<SecAggServerState>> status_or_next_state =
237 state_->ProceedToNextRound();
238 if (status_or_next_state.ok()) {
239 TransitionState(std::move(status_or_next_state.value()));
240 }
241 return status_or_next_state.status();
242 }
243
ReceiveMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)244 StatusOr<bool> SecAggServer::ReceiveMessage(
245 uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
246 TracingSpan<ReceiveSecAggMessage> span(state_span_->Ref(), client_id);
247 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
248 FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
249 FCP_RETURN_IF_ERROR(state_->HandleMessage(client_id, std::move(message)));
250 return ReadyForNextRound();
251 }
252
SetAsyncCallback(std::function<void ()> async_callback)253 bool SecAggServer::SetAsyncCallback(std::function<void()> async_callback) {
254 return state_->SetAsyncCallback(async_callback);
255 }
256
TransitionState(std::unique_ptr<SecAggServerState> new_state)257 void SecAggServer::TransitionState(
258 std::unique_ptr<SecAggServerState> new_state) {
259 // Reset state_span_ before creating a new unscoped span for the next state
260 // to ensure old span is destructed before the new one is created.
261 state_span_.reset();
262 state_ = std::move(new_state);
263 state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
264 span_.Ref(), TracingState(state_->State()));
265 state_->EnterState();
266 }
267
AbortedClientIds() const268 absl::flat_hash_set<uint32_t> SecAggServer::AbortedClientIds() const {
269 return state_->AbortedClientIds();
270 }
271
ErrorMessage() const272 StatusOr<std::string> SecAggServer::ErrorMessage() const {
273 return state_->ErrorMessage();
274 }
275
IsAborted() const276 bool SecAggServer::IsAborted() const { return state_->IsAborted(); }
277
IsCompletedSuccessfully() const278 bool SecAggServer::IsCompletedSuccessfully() const {
279 return state_->IsCompletedSuccessfully();
280 }
281
IsNumberOfIncludedInputsCommitted() const282 bool SecAggServer::IsNumberOfIncludedInputsCommitted() const {
283 return state_->IsNumberOfIncludedInputsCommitted();
284 }
285
MinimumMessagesNeededForNextRound() const286 StatusOr<int> SecAggServer::MinimumMessagesNeededForNextRound() const {
287 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
288 return state_->MinimumMessagesNeededForNextRound();
289 }
290
NumberOfAliveClients() const291 int SecAggServer::NumberOfAliveClients() const {
292 return state_->NumberOfAliveClients();
293 }
294
NumberOfClientsFailedAfterSendingMaskedInput() const295 int SecAggServer::NumberOfClientsFailedAfterSendingMaskedInput() const {
296 return state_->NumberOfClientsFailedAfterSendingMaskedInput();
297 }
298
NumberOfClientsFailedBeforeSendingMaskedInput() const299 int SecAggServer::NumberOfClientsFailedBeforeSendingMaskedInput() const {
300 return state_->NumberOfClientsFailedBeforeSendingMaskedInput();
301 }
302
NumberOfClientsTerminatedWithoutUnmasking() const303 int SecAggServer::NumberOfClientsTerminatedWithoutUnmasking() const {
304 return state_->NumberOfClientsTerminatedWithoutUnmasking();
305 }
306
NumberOfIncludedInputs() const307 int SecAggServer::NumberOfIncludedInputs() const {
308 return state_->NumberOfIncludedInputs();
309 }
310
NumberOfPendingClients() const311 int SecAggServer::NumberOfPendingClients() const {
312 return state_->NumberOfPendingClients();
313 }
314
NumberOfClientsReadyForNextRound() const315 StatusOr<int> SecAggServer::NumberOfClientsReadyForNextRound() const {
316 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
317 return state_->NumberOfClientsReadyForNextRound();
318 }
319
NumberOfMessagesReceivedInThisRound() const320 StatusOr<int> SecAggServer::NumberOfMessagesReceivedInThisRound() const {
321 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
322 return state_->NumberOfMessagesReceivedInThisRound();
323 }
324
ReadyForNextRound() const325 StatusOr<bool> SecAggServer::ReadyForNextRound() const {
326 FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
327 return state_->ReadyForNextRound();
328 }
329
Result()330 StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServer::Result() {
331 return state_->Result();
332 }
333
NumberOfNeighbors() const334 int SecAggServer::NumberOfNeighbors() const {
335 return state_->number_of_neighbors();
336 }
337
MinimumSurvivingNeighborsForReconstruction() const338 int SecAggServer::MinimumSurvivingNeighborsForReconstruction() const {
339 return state_->minimum_surviving_neighbors_for_reconstruction();
340 }
341
State() const342 SecAggServerStateKind SecAggServer::State() const { return state_->State(); }
343
ValidateClientId(uint32_t client_id) const344 Status SecAggServer::ValidateClientId(uint32_t client_id) const {
345 if (client_id >= state_->total_number_of_clients()) {
346 return FCP_STATUS(FAILED_PRECONDITION)
347 << "Client Id " << client_id
348 << " is outside of the expected bounds - 0 to "
349 << state_->total_number_of_clients();
350 }
351 return FCP_STATUS(OK);
352 }
353
ErrorIfAbortedOrCompleted() const354 Status SecAggServer::ErrorIfAbortedOrCompleted() const {
355 if (state_->IsAborted()) {
356 return FCP_STATUS(FAILED_PRECONDITION)
357 << "The server has already aborted. The request cannot be "
358 "satisfied.";
359 }
360 if (state_->IsCompletedSuccessfully()) {
361 return FCP_STATUS(FAILED_PRECONDITION)
362 << "The server has already completed the protocol. "
363 << "Call getOutput() to retrieve the output.";
364 }
365 return FCP_STATUS(OK);
366 }
367
368 } // namespace secagg
369 } // namespace fcp
370