xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_state.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_state.h"
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <unordered_set>
23 #include <utility>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/node_hash_set.h"
28 #include "absl/time/time.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/server/secagg_server_aborted_state.h"
31 #include "fcp/secagg/server/secagg_server_enums.pb.h"
32 #include "fcp/secagg/server/secagg_trace_utility.h"
33 #include "fcp/secagg/server/tracing_schema.h"
34 #include "fcp/secagg/shared/secagg_messages.pb.h"
35 #include "fcp/tracing/tracing_span.h"
36 
37 namespace fcp {
38 namespace secagg {
39 
SecAggServerState(int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking,SecAggServerStateKind state_kind,std::unique_ptr<SecAggServerProtocolImpl> impl)40 SecAggServerState::SecAggServerState(
41     int number_of_clients_failed_after_sending_masked_input,
42     int number_of_clients_failed_before_sending_masked_input,
43     int number_of_clients_terminated_without_unmasking,
44     SecAggServerStateKind state_kind,
45     std::unique_ptr<SecAggServerProtocolImpl> impl)
46     : needs_to_abort_(false),
47       number_of_clients_failed_after_sending_masked_input_(
48           number_of_clients_failed_after_sending_masked_input),
49       number_of_clients_failed_before_sending_masked_input_(
50           number_of_clients_failed_before_sending_masked_input),
51       number_of_clients_ready_for_next_round_(0),
52       number_of_clients_terminated_without_unmasking_(
53           number_of_clients_terminated_without_unmasking),
54       number_of_messages_received_in_this_round_(0),
55       round_start_(absl::Now()),
56       state_kind_(state_kind),
57       impl_(std::move(impl)) {}
58 
~SecAggServerState()59 SecAggServerState::~SecAggServerState() {}
60 
ExitState(StateTransition state_transition_status)61 std::unique_ptr<SecAggServerProtocolImpl>&& SecAggServerState::ExitState(
62     StateTransition state_transition_status) {
63   bool record_success = state_transition_status == StateTransition::kSuccess;
64   auto elapsed_time = absl::ToInt64Milliseconds(absl::Now() - round_start_);
65   if (metrics()) {
66     metrics()->RoundTimes(state_kind_, record_success, elapsed_time);
67     metrics()->RoundSurvivingClients(state_kind_, NumberOfAliveClients());
68 
69     // Fractions of clients by state
70     absl::flat_hash_map<ClientStatus, int> counts_by_state;
71     for (uint32_t i = 0; i < total_number_of_clients(); i++) {
72       counts_by_state[client_status(i)]++;
73     }
74     for (const auto& count_by_state : counts_by_state) {
75       double fraction = static_cast<double>(count_by_state.second) /
76                         total_number_of_clients();
77       Trace<ClientCountsPerState>(TracingState(state_kind_),
78                                   ClientStatusType(count_by_state.first),
79                                   count_by_state.second, fraction);
80       metrics()->RoundCompletionFractions(state_kind_, count_by_state.first,
81                                           fraction);
82     }
83   }
84   Trace<StateCompletion>(TracingState(state_kind_), record_success,
85                          elapsed_time, NumberOfAliveClients());
86   return std::move(impl_);
87 }
88 
89 // These methods return default values unless overridden.
IsAborted() const90 bool SecAggServerState::IsAborted() const { return false; }
IsCompletedSuccessfully() const91 bool SecAggServerState::IsCompletedSuccessfully() const { return false; }
NumberOfPendingClients() const92 int SecAggServerState::NumberOfPendingClients() const { return 0; }
NumberOfIncludedInputs() const93 int SecAggServerState::NumberOfIncludedInputs() const { return 0; }
MinimumMessagesNeededForNextRound() const94 int SecAggServerState::MinimumMessagesNeededForNextRound() const { return 0; }
ReadyForNextRound() const95 bool SecAggServerState::ReadyForNextRound() const { return false; }
96 
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)97 Status SecAggServerState::HandleMessage(
98     uint32_t client_id, const ClientToServerWrapperMessage& message) {
99   MessageReceived(message, false);
100   if (message.message_content_case() ==
101       ClientToServerWrapperMessage::MESSAGE_CONTENT_NOT_SET) {
102     return FCP_STATUS(FAILED_PRECONDITION)
103            << "Server received a message of unknown type from client "
104            << client_id << " but was in state " << StateName();
105   } else {
106     return FCP_STATUS(FAILED_PRECONDITION)
107            << "Server received a message of type "
108            << message.message_content_case() << " from client " << client_id
109            << " but was in state " << StateName();
110   }
111 }
112 
HandleMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)113 Status SecAggServerState::HandleMessage(
114     uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
115   return HandleMessage(client_id, *message);
116 }
117 
118 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()119 SecAggServerState::ProceedToNextRound() {
120   return FCP_STATUS(FAILED_PRECONDITION)
121          << "The server cannot proceed to next round from state "
122          << StateName();
123 }
124 
IsClientDead(uint32_t client_id) const125 bool SecAggServerState::IsClientDead(uint32_t client_id) const {
126   switch (client_status(client_id)) {
127     case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
128     case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
129     case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
130     case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
131     case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
132       return true;
133       break;
134     default:
135       return false;
136   }
137 }
138 
AbortClient(uint32_t client_id,const std::string & reason,ClientDropReason reason_code,bool notify,bool log_metrics)139 void SecAggServerState::AbortClient(uint32_t client_id,
140                                     const std::string& reason,
141                                     ClientDropReason reason_code, bool notify,
142                                     bool log_metrics) {
143   FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
144 
145   if (IsClientDead(client_id)) {
146     return;  // without sending a message
147   }
148 
149   HandleAbortClient(client_id, reason_code);
150   if (notify) {
151     ServerToClientWrapperMessage message;
152     message.mutable_abort()->set_diagnostic_info(reason);
153     message.mutable_abort()->set_early_success(reason_code ==
154                                                ClientDropReason::EARLY_SUCCESS);
155     Send(client_id, message);
156   }
157   // Clients that have successfully completed the protocol should not be logging
158   // metrics.
159   if (metrics() && log_metrics &&
160       client_status(client_id) !=
161           ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED) {
162     metrics()->ClientsDropped(client_status(client_id), reason_code);
163   }
164   auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
165   Trace<ClientsDropped>(ClientStatusType(client_status(client_id)),
166                         ClientDropReasonType(reason_code), elapsed_millis,
167                         reason);
168 }
169 
AbortState(const std::string & reason,SecAggServerOutcome outcome)170 std::unique_ptr<SecAggServerState> SecAggServerState::AbortState(
171     const std::string& reason, SecAggServerOutcome outcome) {
172   if (metrics()) {
173     metrics()->ProtocolOutcomes(outcome);
174   }
175   Trace<SecAggProtocolOutcome>(ConvertSecAccServerOutcomeToTrace(outcome));
176   return std::make_unique<SecAggServerAbortedState>(
177       reason, ExitState(StateTransition::kAbort),
178       number_of_clients_failed_after_sending_masked_input_,
179       number_of_clients_failed_before_sending_masked_input_,
180       number_of_clients_terminated_without_unmasking_);
181 }
182 
Abort(const std::string & reason,SecAggServerOutcome outcome)183 std::unique_ptr<SecAggServerState> SecAggServerState::Abort(
184     const std::string& reason, SecAggServerOutcome outcome) {
185   FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
186 
187   HandleAbort();
188 
189   ServerToClientWrapperMessage message;
190   message.mutable_abort()->set_early_success(false);
191   message.mutable_abort()->set_diagnostic_info(reason);
192   SendBroadcast(message);
193 
194   return AbortState(reason, outcome);
195 }
196 
ErrorMessage() const197 StatusOr<std::string> SecAggServerState::ErrorMessage() const {
198   return FCP_STATUS(FAILED_PRECONDITION)
199          << "Error message requested, but server is in state " << StateName();
200 }
201 
NumberOfAliveClients() const202 int SecAggServerState::NumberOfAliveClients() const {
203   return total_number_of_clients() -
204          number_of_clients_failed_before_sending_masked_input_ -
205          number_of_clients_failed_after_sending_masked_input_ -
206          number_of_clients_terminated_without_unmasking_;
207 }
208 
NumberOfMessagesReceivedInThisRound() const209 int SecAggServerState::NumberOfMessagesReceivedInThisRound() const {
210   return number_of_messages_received_in_this_round_;
211 }
212 
NumberOfClientsReadyForNextRound() const213 int SecAggServerState::NumberOfClientsReadyForNextRound() const {
214   return number_of_clients_ready_for_next_round_;
215 }
216 
NumberOfClientsFailedAfterSendingMaskedInput() const217 int SecAggServerState::NumberOfClientsFailedAfterSendingMaskedInput() const {
218   return number_of_clients_failed_after_sending_masked_input_;
219 }
220 
NumberOfClientsFailedBeforeSendingMaskedInput() const221 int SecAggServerState::NumberOfClientsFailedBeforeSendingMaskedInput() const {
222   return number_of_clients_failed_before_sending_masked_input_;
223 }
224 
NumberOfClientsTerminatedWithoutUnmasking() const225 int SecAggServerState::NumberOfClientsTerminatedWithoutUnmasking() const {
226   return number_of_clients_terminated_without_unmasking_;
227 }
228 
NeedsToAbort() const229 bool SecAggServerState::NeedsToAbort() const { return needs_to_abort_; }
230 
AbortedClientIds() const231 absl::flat_hash_set<uint32_t> SecAggServerState::AbortedClientIds() const {
232   auto aborted_client_ids_ = absl::flat_hash_set<uint32_t>();
233   for (int i = 0; i < total_number_of_clients(); ++i) {
234     // Clients that have successfully completed the protocol are not reported
235     // as aborted.
236     if (IsClientDead(i)) {
237       aborted_client_ids_.insert(i);
238     }
239   }
240   return aborted_client_ids_;
241 }
242 
SetAsyncCallback(std::function<void ()> async_callback)243 bool SecAggServerState::SetAsyncCallback(std::function<void()> async_callback) {
244   return false;
245 }
246 
Result()247 StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServerState::Result() {
248   return FCP_STATUS(UNAVAILABLE)
249          << "Result requested, but server is in state " << StateName();
250 }
251 
State() const252 SecAggServerStateKind SecAggServerState::State() const { return state_kind_; }
253 
StateName() const254 std::string SecAggServerState::StateName() const {
255   switch (state_kind_) {
256     case SecAggServerStateKind::ABORTED:
257       return "Aborted";
258     case SecAggServerStateKind::COMPLETED:
259       return "Completed";
260     case SecAggServerStateKind::PRNG_RUNNING:
261       return "PrngRunning";
262     case SecAggServerStateKind::R0_ADVERTISE_KEYS:
263       return "R0AdvertiseKeys";
264     case SecAggServerStateKind::R1_SHARE_KEYS:
265       return "R1ShareKeys";
266     case SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION:
267       return "R2MaskedInputCollection";
268     case SecAggServerStateKind::R3_UNMASKING:
269       return "R3Unmasking";
270     default:
271       return "Unknown";
272   }
273 }
274 
MessageReceived(const ClientToServerWrapperMessage & message,bool expected)275 void SecAggServerState::MessageReceived(
276     const ClientToServerWrapperMessage& message, bool expected) {
277   auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
278   if (metrics()) {
279     if (expected) {
280       metrics()->ClientResponseTimes(message.message_content_case(),
281                                      elapsed_millis);
282     }
283     metrics()->MessageReceivedSizes(message.message_content_case(), expected,
284                                     message.ByteSizeLong());
285   }
286   Trace<ClientMessageReceived>(GetClientToServerMessageType(message),
287                                message.ByteSizeLong(), expected,
288                                elapsed_millis);
289 }
290 
SendBroadcast(const ServerToClientWrapperMessage & message)291 void SecAggServerState::SendBroadcast(
292     const ServerToClientWrapperMessage& message) {
293   FCP_CHECK(message.message_content_case() !=
294             ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
295   if (metrics()) {
296     metrics()->BroadcastMessageSizes(message.message_content_case(),
297                                      message.ByteSizeLong());
298   }
299   sender()->SendBroadcast(message);
300   Trace<BroadcastMessageSent>(GetServerToClientMessageType(message),
301                               message.ByteSizeLong());
302 }
303 
Send(uint32_t recipient_id,const ServerToClientWrapperMessage & message)304 void SecAggServerState::Send(uint32_t recipient_id,
305                              const ServerToClientWrapperMessage& message) {
306   FCP_CHECK(message.message_content_case() !=
307             ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
308   if (metrics()) {
309     metrics()->IndividualMessageSizes(message.message_content_case(),
310                                       message.ByteSizeLong());
311   }
312   sender()->Send(recipient_id, message);
313 
314   Trace<IndividualMessageSent>(recipient_id,
315                                GetServerToClientMessageType(message),
316                                message.ByteSizeLong());
317 }
318 
319 }  // namespace secagg
320 }  // namespace fcp
321