1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_state.h"
18
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <unordered_set>
23 #include <utility>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/node_hash_set.h"
28 #include "absl/time/time.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/server/secagg_server_aborted_state.h"
31 #include "fcp/secagg/server/secagg_server_enums.pb.h"
32 #include "fcp/secagg/server/secagg_trace_utility.h"
33 #include "fcp/secagg/server/tracing_schema.h"
34 #include "fcp/secagg/shared/secagg_messages.pb.h"
35 #include "fcp/tracing/tracing_span.h"
36
37 namespace fcp {
38 namespace secagg {
39
SecAggServerState(int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking,SecAggServerStateKind state_kind,std::unique_ptr<SecAggServerProtocolImpl> impl)40 SecAggServerState::SecAggServerState(
41 int number_of_clients_failed_after_sending_masked_input,
42 int number_of_clients_failed_before_sending_masked_input,
43 int number_of_clients_terminated_without_unmasking,
44 SecAggServerStateKind state_kind,
45 std::unique_ptr<SecAggServerProtocolImpl> impl)
46 : needs_to_abort_(false),
47 number_of_clients_failed_after_sending_masked_input_(
48 number_of_clients_failed_after_sending_masked_input),
49 number_of_clients_failed_before_sending_masked_input_(
50 number_of_clients_failed_before_sending_masked_input),
51 number_of_clients_ready_for_next_round_(0),
52 number_of_clients_terminated_without_unmasking_(
53 number_of_clients_terminated_without_unmasking),
54 number_of_messages_received_in_this_round_(0),
55 round_start_(absl::Now()),
56 state_kind_(state_kind),
57 impl_(std::move(impl)) {}
58
~SecAggServerState()59 SecAggServerState::~SecAggServerState() {}
60
ExitState(StateTransition state_transition_status)61 std::unique_ptr<SecAggServerProtocolImpl>&& SecAggServerState::ExitState(
62 StateTransition state_transition_status) {
63 bool record_success = state_transition_status == StateTransition::kSuccess;
64 auto elapsed_time = absl::ToInt64Milliseconds(absl::Now() - round_start_);
65 if (metrics()) {
66 metrics()->RoundTimes(state_kind_, record_success, elapsed_time);
67 metrics()->RoundSurvivingClients(state_kind_, NumberOfAliveClients());
68
69 // Fractions of clients by state
70 absl::flat_hash_map<ClientStatus, int> counts_by_state;
71 for (uint32_t i = 0; i < total_number_of_clients(); i++) {
72 counts_by_state[client_status(i)]++;
73 }
74 for (const auto& count_by_state : counts_by_state) {
75 double fraction = static_cast<double>(count_by_state.second) /
76 total_number_of_clients();
77 Trace<ClientCountsPerState>(TracingState(state_kind_),
78 ClientStatusType(count_by_state.first),
79 count_by_state.second, fraction);
80 metrics()->RoundCompletionFractions(state_kind_, count_by_state.first,
81 fraction);
82 }
83 }
84 Trace<StateCompletion>(TracingState(state_kind_), record_success,
85 elapsed_time, NumberOfAliveClients());
86 return std::move(impl_);
87 }
88
89 // These methods return default values unless overridden.
IsAborted() const90 bool SecAggServerState::IsAborted() const { return false; }
IsCompletedSuccessfully() const91 bool SecAggServerState::IsCompletedSuccessfully() const { return false; }
NumberOfPendingClients() const92 int SecAggServerState::NumberOfPendingClients() const { return 0; }
NumberOfIncludedInputs() const93 int SecAggServerState::NumberOfIncludedInputs() const { return 0; }
MinimumMessagesNeededForNextRound() const94 int SecAggServerState::MinimumMessagesNeededForNextRound() const { return 0; }
ReadyForNextRound() const95 bool SecAggServerState::ReadyForNextRound() const { return false; }
96
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)97 Status SecAggServerState::HandleMessage(
98 uint32_t client_id, const ClientToServerWrapperMessage& message) {
99 MessageReceived(message, false);
100 if (message.message_content_case() ==
101 ClientToServerWrapperMessage::MESSAGE_CONTENT_NOT_SET) {
102 return FCP_STATUS(FAILED_PRECONDITION)
103 << "Server received a message of unknown type from client "
104 << client_id << " but was in state " << StateName();
105 } else {
106 return FCP_STATUS(FAILED_PRECONDITION)
107 << "Server received a message of type "
108 << message.message_content_case() << " from client " << client_id
109 << " but was in state " << StateName();
110 }
111 }
112
HandleMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)113 Status SecAggServerState::HandleMessage(
114 uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
115 return HandleMessage(client_id, *message);
116 }
117
118 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()119 SecAggServerState::ProceedToNextRound() {
120 return FCP_STATUS(FAILED_PRECONDITION)
121 << "The server cannot proceed to next round from state "
122 << StateName();
123 }
124
IsClientDead(uint32_t client_id) const125 bool SecAggServerState::IsClientDead(uint32_t client_id) const {
126 switch (client_status(client_id)) {
127 case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
128 case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
129 case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
130 case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
131 case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
132 return true;
133 break;
134 default:
135 return false;
136 }
137 }
138
AbortClient(uint32_t client_id,const std::string & reason,ClientDropReason reason_code,bool notify,bool log_metrics)139 void SecAggServerState::AbortClient(uint32_t client_id,
140 const std::string& reason,
141 ClientDropReason reason_code, bool notify,
142 bool log_metrics) {
143 FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
144
145 if (IsClientDead(client_id)) {
146 return; // without sending a message
147 }
148
149 HandleAbortClient(client_id, reason_code);
150 if (notify) {
151 ServerToClientWrapperMessage message;
152 message.mutable_abort()->set_diagnostic_info(reason);
153 message.mutable_abort()->set_early_success(reason_code ==
154 ClientDropReason::EARLY_SUCCESS);
155 Send(client_id, message);
156 }
157 // Clients that have successfully completed the protocol should not be logging
158 // metrics.
159 if (metrics() && log_metrics &&
160 client_status(client_id) !=
161 ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED) {
162 metrics()->ClientsDropped(client_status(client_id), reason_code);
163 }
164 auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
165 Trace<ClientsDropped>(ClientStatusType(client_status(client_id)),
166 ClientDropReasonType(reason_code), elapsed_millis,
167 reason);
168 }
169
AbortState(const std::string & reason,SecAggServerOutcome outcome)170 std::unique_ptr<SecAggServerState> SecAggServerState::AbortState(
171 const std::string& reason, SecAggServerOutcome outcome) {
172 if (metrics()) {
173 metrics()->ProtocolOutcomes(outcome);
174 }
175 Trace<SecAggProtocolOutcome>(ConvertSecAccServerOutcomeToTrace(outcome));
176 return std::make_unique<SecAggServerAbortedState>(
177 reason, ExitState(StateTransition::kAbort),
178 number_of_clients_failed_after_sending_masked_input_,
179 number_of_clients_failed_before_sending_masked_input_,
180 number_of_clients_terminated_without_unmasking_);
181 }
182
Abort(const std::string & reason,SecAggServerOutcome outcome)183 std::unique_ptr<SecAggServerState> SecAggServerState::Abort(
184 const std::string& reason, SecAggServerOutcome outcome) {
185 FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
186
187 HandleAbort();
188
189 ServerToClientWrapperMessage message;
190 message.mutable_abort()->set_early_success(false);
191 message.mutable_abort()->set_diagnostic_info(reason);
192 SendBroadcast(message);
193
194 return AbortState(reason, outcome);
195 }
196
ErrorMessage() const197 StatusOr<std::string> SecAggServerState::ErrorMessage() const {
198 return FCP_STATUS(FAILED_PRECONDITION)
199 << "Error message requested, but server is in state " << StateName();
200 }
201
NumberOfAliveClients() const202 int SecAggServerState::NumberOfAliveClients() const {
203 return total_number_of_clients() -
204 number_of_clients_failed_before_sending_masked_input_ -
205 number_of_clients_failed_after_sending_masked_input_ -
206 number_of_clients_terminated_without_unmasking_;
207 }
208
NumberOfMessagesReceivedInThisRound() const209 int SecAggServerState::NumberOfMessagesReceivedInThisRound() const {
210 return number_of_messages_received_in_this_round_;
211 }
212
NumberOfClientsReadyForNextRound() const213 int SecAggServerState::NumberOfClientsReadyForNextRound() const {
214 return number_of_clients_ready_for_next_round_;
215 }
216
NumberOfClientsFailedAfterSendingMaskedInput() const217 int SecAggServerState::NumberOfClientsFailedAfterSendingMaskedInput() const {
218 return number_of_clients_failed_after_sending_masked_input_;
219 }
220
NumberOfClientsFailedBeforeSendingMaskedInput() const221 int SecAggServerState::NumberOfClientsFailedBeforeSendingMaskedInput() const {
222 return number_of_clients_failed_before_sending_masked_input_;
223 }
224
NumberOfClientsTerminatedWithoutUnmasking() const225 int SecAggServerState::NumberOfClientsTerminatedWithoutUnmasking() const {
226 return number_of_clients_terminated_without_unmasking_;
227 }
228
NeedsToAbort() const229 bool SecAggServerState::NeedsToAbort() const { return needs_to_abort_; }
230
AbortedClientIds() const231 absl::flat_hash_set<uint32_t> SecAggServerState::AbortedClientIds() const {
232 auto aborted_client_ids_ = absl::flat_hash_set<uint32_t>();
233 for (int i = 0; i < total_number_of_clients(); ++i) {
234 // Clients that have successfully completed the protocol are not reported
235 // as aborted.
236 if (IsClientDead(i)) {
237 aborted_client_ids_.insert(i);
238 }
239 }
240 return aborted_client_ids_;
241 }
242
SetAsyncCallback(std::function<void ()> async_callback)243 bool SecAggServerState::SetAsyncCallback(std::function<void()> async_callback) {
244 return false;
245 }
246
Result()247 StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServerState::Result() {
248 return FCP_STATUS(UNAVAILABLE)
249 << "Result requested, but server is in state " << StateName();
250 }
251
State() const252 SecAggServerStateKind SecAggServerState::State() const { return state_kind_; }
253
StateName() const254 std::string SecAggServerState::StateName() const {
255 switch (state_kind_) {
256 case SecAggServerStateKind::ABORTED:
257 return "Aborted";
258 case SecAggServerStateKind::COMPLETED:
259 return "Completed";
260 case SecAggServerStateKind::PRNG_RUNNING:
261 return "PrngRunning";
262 case SecAggServerStateKind::R0_ADVERTISE_KEYS:
263 return "R0AdvertiseKeys";
264 case SecAggServerStateKind::R1_SHARE_KEYS:
265 return "R1ShareKeys";
266 case SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION:
267 return "R2MaskedInputCollection";
268 case SecAggServerStateKind::R3_UNMASKING:
269 return "R3Unmasking";
270 default:
271 return "Unknown";
272 }
273 }
274
MessageReceived(const ClientToServerWrapperMessage & message,bool expected)275 void SecAggServerState::MessageReceived(
276 const ClientToServerWrapperMessage& message, bool expected) {
277 auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
278 if (metrics()) {
279 if (expected) {
280 metrics()->ClientResponseTimes(message.message_content_case(),
281 elapsed_millis);
282 }
283 metrics()->MessageReceivedSizes(message.message_content_case(), expected,
284 message.ByteSizeLong());
285 }
286 Trace<ClientMessageReceived>(GetClientToServerMessageType(message),
287 message.ByteSizeLong(), expected,
288 elapsed_millis);
289 }
290
SendBroadcast(const ServerToClientWrapperMessage & message)291 void SecAggServerState::SendBroadcast(
292 const ServerToClientWrapperMessage& message) {
293 FCP_CHECK(message.message_content_case() !=
294 ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
295 if (metrics()) {
296 metrics()->BroadcastMessageSizes(message.message_content_case(),
297 message.ByteSizeLong());
298 }
299 sender()->SendBroadcast(message);
300 Trace<BroadcastMessageSent>(GetServerToClientMessageType(message),
301 message.ByteSizeLong());
302 }
303
Send(uint32_t recipient_id,const ServerToClientWrapperMessage & message)304 void SecAggServerState::Send(uint32_t recipient_id,
305 const ServerToClientWrapperMessage& message) {
306 FCP_CHECK(message.message_content_case() !=
307 ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
308 if (metrics()) {
309 metrics()->IndividualMessageSizes(message.message_content_case(),
310 message.ByteSizeLong());
311 }
312 sender()->Send(recipient_id, message);
313
314 Trace<IndividualMessageSent>(recipient_id,
315 GetServerToClientMessageType(message),
316 message.ByteSizeLong());
317 }
318
319 } // namespace secagg
320 } // namespace fcp
321