xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fcp/secagg/server/secagg_server.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "absl/strings/str_cat.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
#include "fcp/secagg/server/experiments_names.h"
#include "fcp/secagg/server/graph_parameter_finder.h"
#include "fcp/secagg/server/secagg_scheduler.h"
#include "fcp/secagg/server/secagg_server_aborted_state.h"
#include "fcp/secagg/server/secagg_server_enums.pb.h"
#include "fcp/secagg/server/secagg_server_messages.pb.h"
#include "fcp/secagg/server/secagg_server_metrics_listener.h"
#include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
#include "fcp/secagg/server/secagg_server_state.h"
#include "fcp/secagg/server/secagg_trace_utility.h"
#include "fcp/secagg/server/secret_sharing_graph.h"
#include "fcp/secagg/server/secret_sharing_graph_factory.h"
#include "fcp/secagg/server/send_to_clients_interface.h"
#include "fcp/secagg/server/tracing_schema.h"
#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
#include "fcp/secagg/shared/aes_prng_factory.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/tracing/tracing_span.h"
50 
51 namespace fcp {
52 namespace secagg {
53 
SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl)54 SecAggServer::SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl) {
55   state_ = std::make_unique<SecAggServerR0AdvertiseKeysState>(std::move(impl));
56 
57   // Start the span for the current state. The rest of the state span
58   // transitioning is done in TransitionState.
59   state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
60       span_.Ref(), TracingState(state_->State()));
61 }
62 
Create(int minimum_number_of_clients_to_proceed,int total_number_of_clients,const std::vector<InputVectorSpecification> & input_vector_specs,SendToClientsInterface * sender,std::unique_ptr<SecAggServerMetricsListener> metrics,std::unique_ptr<SecAggScheduler> prng_runner,std::unique_ptr<ExperimentsInterface> experiments,const SecureAggregationRequirements & threat_model)63 StatusOr<std::unique_ptr<SecAggServer>> SecAggServer::Create(
64     int minimum_number_of_clients_to_proceed, int total_number_of_clients,
65     const std::vector<InputVectorSpecification>& input_vector_specs,
66     SendToClientsInterface* sender,
67     std::unique_ptr<SecAggServerMetricsListener> metrics,
68     std::unique_ptr<SecAggScheduler> prng_runner,
69     std::unique_ptr<ExperimentsInterface> experiments,
70     const SecureAggregationRequirements& threat_model) {
71   TracingSpan<CreateSecAggServer> span;
72   SecretSharingGraphFactory factory;
73   std::unique_ptr<SecretSharingGraph> secret_sharing_graph;
74   ServerVariant server_variant = ServerVariant::UNKNOWN_VERSION;
75 
76   bool is_fullgraph_protocol_variant =
77       experiments->IsEnabled(kFullgraphSecAggExperiment);
78   int degree, threshold;
79   // We first compute parameters degree and threshold for the subgraph variant,
80   // unless the kFullgraphSecAggExperiment is enabled, and then set
81   // is_subgraph_protocol_variant to false if the parameter finding procedure
82   // fails. In that case we resort to classical full-graph secagg.
83   // This will happen for very small values of total_number_of_clients (e.g. <
84   // 65), i.e. cohort sizes where subgraph-secagg does not give much advantage.
85   if (!is_fullgraph_protocol_variant) {
86     if (experiments->IsEnabled(kForceSubgraphSecAggExperiment)) {
87       // In kForceSubgraphSecAggExperiment (which is only for testing
88       // purposes) we fix the degree in the Harary graph to be half the number
89       // of clients (rounding to the next odd number to account for self-edges
90       // as above) and degree to be half of the degree (or 2 whatever is
91       // larger). This means that, for example in a simple test with 5
92       // clients, each client shares keys with 2 other clients and the
93       // threshold is one.
94       degree = total_number_of_clients / 2;
95       if (degree % 2 == 0) {
96         degree += 1;
97       }
98       threshold = std::max(2, degree / 2);
99 
100     } else {
101       // kSubgraphSecAggCuriousServerExperiment sets the threat model to
102       // CURIOUS_SERVER in subgraph-secagg executions.
103       // This experiment was introduced as part of go/subgraph-secagg-rollout
104       // and is temporary (see b/191179307).
105       StatusOr<fcp::secagg::HararyGraphParameters>
106           computed_params_status_or_value;
107       if (experiments->IsEnabled(kSubgraphSecAggCuriousServerExperiment)) {
108         SecureAggregationRequirements alternate_threat_model = threat_model;
109         alternate_threat_model.set_adversary_class(
110             AdversaryClass::CURIOUS_SERVER);
111         computed_params_status_or_value = ComputeHararyGraphParameters(
112             total_number_of_clients, alternate_threat_model);
113       } else {
114         computed_params_status_or_value =
115             ComputeHararyGraphParameters(total_number_of_clients, threat_model);
116       }
117       if (computed_params_status_or_value.ok()) {
118         // We add 1 to the computed degree to account for a self-edge in the
119         // SecretSharingHararyGraph graph
120         degree = computed_params_status_or_value->degree + 1;
121         threshold = computed_params_status_or_value->threshold;
122       } else {
123         is_fullgraph_protocol_variant = true;
124       }
125     }
126   }
127 
128   // In both the FullGraph and SubGraph variants, the protocol only successfully
129   // completes and returns a sum if no more than
130   // floor(total_number_of_clients * threat_model.estimated_dropout_rate())
131   // clients dropout before the end of the protocol execution. This ensure that
132   // at least ceil(total_number_of_clients *(1. -
133   // threat_model.estimated_dropout_rate() -
134   // threat_model.adversarial_client_rate)) values from honest clients are
135   // included in the final sum.
136   // The protocol allows to make that threshold larger by providing a larger
137   // value of minimum_number_of_clients_to_proceed to the create function, but
138   // never lower.
139   minimum_number_of_clients_to_proceed =
140       std::max(minimum_number_of_clients_to_proceed,
141                static_cast<int>(
142                    std::ceil(total_number_of_clients *
143                              (1. - threat_model.estimated_dropout_rate()))));
144   if (is_fullgraph_protocol_variant) {
145     // We're instantiating full-graph secagg, either because that was
146     // the intent of the caller (by setting kFullgraphSecAggExperiment), or
147     // because ComputeHararyGraphParameters returned and error.
148     FCP_RETURN_IF_ERROR(CheckFullGraphParameters(
149         total_number_of_clients, minimum_number_of_clients_to_proceed,
150         threat_model));
151     secret_sharing_graph = factory.CreateCompleteGraph(
152         total_number_of_clients, minimum_number_of_clients_to_proceed);
153     server_variant = ServerVariant::NATIVE_V1;
154     Trace<FullGraphServerParameters>(
155         total_number_of_clients, minimum_number_of_clients_to_proceed,
156         experiments->IsEnabled(kSecAggAsyncRound2Experiment));
157   } else {
158     secret_sharing_graph =
159         factory.CreateHararyGraph(total_number_of_clients, degree, threshold);
160     server_variant = ServerVariant::NATIVE_SUBGRAPH;
161     Trace<SubGraphServerParameters>(
162         total_number_of_clients, degree, threshold,
163         minimum_number_of_clients_to_proceed,
164         experiments->IsEnabled(kSecAggAsyncRound2Experiment));
165   }
166 
167   return absl::WrapUnique(
168       new SecAggServer(std::make_unique<AesSecAggServerProtocolImpl>(
169           std::move(secret_sharing_graph), minimum_number_of_clients_to_proceed,
170           input_vector_specs, std::move(metrics),
171           std::make_unique<AesCtrPrngFactory>(), sender, std::move(prng_runner),
172           std::vector<ClientStatus>(total_number_of_clients,
173                                     ClientStatus::READY_TO_START),
174           server_variant, std::move(experiments))));
175 }
176 
Abort()177 Status SecAggServer::Abort() {
178   const std::string reason = "Abort upon external request.";
179   TracingSpan<AbortSecAggServer> span(state_span_->Ref(), reason);
180   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
181   TransitionState(state_->Abort(reason, SecAggServerOutcome::EXTERNAL_REQUEST));
182   return FCP_STATUS(OK);
183 }
184 
Abort(const std::string & reason,SecAggServerOutcome outcome)185 Status SecAggServer::Abort(const std::string& reason,
186                            SecAggServerOutcome outcome) {
187   const std::string formatted_reason =
188       absl::StrCat("Abort upon external request for reason <", reason, ">.");
189   TracingSpan<AbortSecAggServer> span(state_span_->Ref(), formatted_reason);
190   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
191   TransitionState(state_->Abort(formatted_reason, outcome));
192   return FCP_STATUS(OK);
193 }
194 
MakeClientAbortMessage(ClientAbortReason reason)195 std::string MakeClientAbortMessage(ClientAbortReason reason) {
196   return absl::StrCat("The protocol is closing client with ClientAbortReason <",
197                       ClientAbortReason_Name(reason), ">.");
198 }
199 
AbortClient(uint32_t client_id,ClientAbortReason reason)200 Status SecAggServer::AbortClient(uint32_t client_id, ClientAbortReason reason) {
201   TracingSpan<AbortSecAggClient> span(
202       state_span_->Ref(), client_id,
203       ClientAbortReason_descriptor()->FindValueByNumber(reason)->name());
204   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
205   FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
206   // By default, put all AbortClient calls in the same bucket (with some
207   // exceptions below).
208   ClientDropReason client_drop_reason =
209       ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT;
210   bool notify_client = false;
211   bool log_metrics = true;
212   std::string message;
213   // Handle all specific abortClient cases
214   switch (reason) {
215     case ClientAbortReason::INVALID_MESSAGE:
216       notify_client = true;
217       message = MakeClientAbortMessage(reason);
218       break;
219     case ClientAbortReason::CONNECTION_DROPPED:
220       client_drop_reason = ClientDropReason::CONNECTION_CLOSED;
221       break;
222     default:
223       log_metrics = false;
224       message = MakeClientAbortMessage(reason);
225       break;
226   }
227 
228   state_->AbortClient(client_id, message, client_drop_reason, notify_client,
229                       log_metrics);
230   return FCP_STATUS(OK);
231 }
232 
ProceedToNextRound()233 Status SecAggServer::ProceedToNextRound() {
234   TracingSpan<ProceedToNextSecAggRound> span(state_span_->Ref());
235   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
236   StatusOr<std::unique_ptr<SecAggServerState>> status_or_next_state =
237       state_->ProceedToNextRound();
238   if (status_or_next_state.ok()) {
239     TransitionState(std::move(status_or_next_state.value()));
240   }
241   return status_or_next_state.status();
242 }
243 
ReceiveMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)244 StatusOr<bool> SecAggServer::ReceiveMessage(
245     uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
246   TracingSpan<ReceiveSecAggMessage> span(state_span_->Ref(), client_id);
247   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
248   FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
249   FCP_RETURN_IF_ERROR(state_->HandleMessage(client_id, std::move(message)));
250   return ReadyForNextRound();
251 }
252 
SetAsyncCallback(std::function<void ()> async_callback)253 bool SecAggServer::SetAsyncCallback(std::function<void()> async_callback) {
254   return state_->SetAsyncCallback(async_callback);
255 }
256 
TransitionState(std::unique_ptr<SecAggServerState> new_state)257 void SecAggServer::TransitionState(
258     std::unique_ptr<SecAggServerState> new_state) {
259   // Reset state_span_ before creating a new unscoped span for the next state
260   // to ensure old span is destructed before the new one is created.
261   state_span_.reset();
262   state_ = std::move(new_state);
263   state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
264       span_.Ref(), TracingState(state_->State()));
265   state_->EnterState();
266 }
267 
AbortedClientIds() const268 absl::flat_hash_set<uint32_t> SecAggServer::AbortedClientIds() const {
269   return state_->AbortedClientIds();
270 }
271 
ErrorMessage() const272 StatusOr<std::string> SecAggServer::ErrorMessage() const {
273   return state_->ErrorMessage();
274 }
275 
IsAborted() const276 bool SecAggServer::IsAborted() const { return state_->IsAborted(); }
277 
IsCompletedSuccessfully() const278 bool SecAggServer::IsCompletedSuccessfully() const {
279   return state_->IsCompletedSuccessfully();
280 }
281 
IsNumberOfIncludedInputsCommitted() const282 bool SecAggServer::IsNumberOfIncludedInputsCommitted() const {
283   return state_->IsNumberOfIncludedInputsCommitted();
284 }
285 
MinimumMessagesNeededForNextRound() const286 StatusOr<int> SecAggServer::MinimumMessagesNeededForNextRound() const {
287   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
288   return state_->MinimumMessagesNeededForNextRound();
289 }
290 
NumberOfAliveClients() const291 int SecAggServer::NumberOfAliveClients() const {
292   return state_->NumberOfAliveClients();
293 }
294 
NumberOfClientsFailedAfterSendingMaskedInput() const295 int SecAggServer::NumberOfClientsFailedAfterSendingMaskedInput() const {
296   return state_->NumberOfClientsFailedAfterSendingMaskedInput();
297 }
298 
NumberOfClientsFailedBeforeSendingMaskedInput() const299 int SecAggServer::NumberOfClientsFailedBeforeSendingMaskedInput() const {
300   return state_->NumberOfClientsFailedBeforeSendingMaskedInput();
301 }
302 
NumberOfClientsTerminatedWithoutUnmasking() const303 int SecAggServer::NumberOfClientsTerminatedWithoutUnmasking() const {
304   return state_->NumberOfClientsTerminatedWithoutUnmasking();
305 }
306 
NumberOfIncludedInputs() const307 int SecAggServer::NumberOfIncludedInputs() const {
308   return state_->NumberOfIncludedInputs();
309 }
310 
NumberOfPendingClients() const311 int SecAggServer::NumberOfPendingClients() const {
312   return state_->NumberOfPendingClients();
313 }
314 
NumberOfClientsReadyForNextRound() const315 StatusOr<int> SecAggServer::NumberOfClientsReadyForNextRound() const {
316   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
317   return state_->NumberOfClientsReadyForNextRound();
318 }
319 
NumberOfMessagesReceivedInThisRound() const320 StatusOr<int> SecAggServer::NumberOfMessagesReceivedInThisRound() const {
321   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
322   return state_->NumberOfMessagesReceivedInThisRound();
323 }
324 
ReadyForNextRound() const325 StatusOr<bool> SecAggServer::ReadyForNextRound() const {
326   FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
327   return state_->ReadyForNextRound();
328 }
329 
Result()330 StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServer::Result() {
331   return state_->Result();
332 }
333 
NumberOfNeighbors() const334 int SecAggServer::NumberOfNeighbors() const {
335   return state_->number_of_neighbors();
336 }
337 
MinimumSurvivingNeighborsForReconstruction() const338 int SecAggServer::MinimumSurvivingNeighborsForReconstruction() const {
339   return state_->minimum_surviving_neighbors_for_reconstruction();
340 }
341 
State() const342 SecAggServerStateKind SecAggServer::State() const { return state_->State(); }
343 
ValidateClientId(uint32_t client_id) const344 Status SecAggServer::ValidateClientId(uint32_t client_id) const {
345   if (client_id >= state_->total_number_of_clients()) {
346     return FCP_STATUS(FAILED_PRECONDITION)
347            << "Client Id " << client_id
348            << " is outside of the expected bounds - 0 to "
349            << state_->total_number_of_clients();
350   }
351   return FCP_STATUS(OK);
352 }
353 
ErrorIfAbortedOrCompleted() const354 Status SecAggServer::ErrorIfAbortedOrCompleted() const {
355   if (state_->IsAborted()) {
356     return FCP_STATUS(FAILED_PRECONDITION)
357            << "The server has already aborted. The request cannot be "
358               "satisfied.";
359   }
360   if (state_->IsCompletedSuccessfully()) {
361     return FCP_STATUS(FAILED_PRECONDITION)
362            << "The server has already completed the protocol. "
363            << "Call getOutput() to retrieve the output.";
364   }
365   return FCP_STATUS(OK);
366 }
367 
368 }  // namespace secagg
369 }  // namespace fcp
370