xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_r3_unmasking_state.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "fcp/base/monitoring.h"
25 #include "fcp/secagg/server/secagg_server_prng_running_state.h"
26 
27 namespace fcp {
28 namespace secagg {
29 
SecAggServerR3UnmaskingState(std::unique_ptr<SecAggServerProtocolImpl> impl,int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking)30 SecAggServerR3UnmaskingState::SecAggServerR3UnmaskingState(
31     std::unique_ptr<SecAggServerProtocolImpl> impl,
32     int number_of_clients_failed_after_sending_masked_input,
33     int number_of_clients_failed_before_sending_masked_input,
34     int number_of_clients_terminated_without_unmasking)
35     : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
36                         number_of_clients_failed_before_sending_masked_input,
37                         number_of_clients_terminated_without_unmasking,
38                         SecAggServerStateKind::R3_UNMASKING, std::move(impl)) {
39   this->impl()->SetUpShamirSharesTables();
40 }
41 
~SecAggServerR3UnmaskingState()42 SecAggServerR3UnmaskingState::~SecAggServerR3UnmaskingState() {}
43 
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)44 Status SecAggServerR3UnmaskingState::HandleMessage(
45     uint32_t client_id, const ClientToServerWrapperMessage& message) {
46   if (message.has_abort()) {
47     MessageReceived(message, false);
48     AbortClient(client_id, "Client sent abort message.",
49                 ClientDropReason::SENT_ABORT_MESSAGE,
50                 /*notify=*/false);
51     return FCP_STATUS(OK);
52   }
53   // If the client has aborted already, ignore its messages.
54   if (client_status(client_id) !=
55       ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
56     MessageReceived(message, false);
57     AbortClient(
58         client_id,
59         "Not expecting an UnmaskingResponse from this client - either the "
60         "client already aborted or one such message was already received.",
61         ClientDropReason::UNMASKING_RESPONSE_UNEXPECTED);
62     return FCP_STATUS(OK);
63   }
64   if (!message.has_unmasking_response()) {
65     MessageReceived(message, false);
66     AbortClient(client_id,
67                 "Message type received is different from what was expected.",
68                 ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
69     return FCP_STATUS(OK);
70   }
71   MessageReceived(message, true);
72 
73   Status status =
74       impl()->HandleUnmaskingResponse(client_id, message.unmasking_response());
75   if (!status.ok()) {
76     AbortClient(client_id, std::string(status.message()),
77                 ClientDropReason::INVALID_UNMASKING_RESPONSE);
78     return FCP_STATUS(OK);
79   }
80 
81   set_client_status(client_id, ClientStatus::UNMASKING_RESPONSE_RECEIVED);
82   number_of_messages_received_in_this_round_++;
83   number_of_clients_ready_for_next_round_++;
84   return FCP_STATUS(OK);
85 }
86 
IsNumberOfIncludedInputsCommitted() const87 bool SecAggServerR3UnmaskingState::IsNumberOfIncludedInputsCommitted() const {
88   return true;
89 }
90 
MinimumMessagesNeededForNextRound() const91 int SecAggServerR3UnmaskingState::MinimumMessagesNeededForNextRound() const {
92   return std::max(0, minimum_number_of_clients_to_proceed() -
93                          number_of_messages_received_in_this_round_);
94 }
95 
NumberOfIncludedInputs() const96 int SecAggServerR3UnmaskingState::NumberOfIncludedInputs() const {
97   return total_number_of_clients() -
98          number_of_clients_failed_before_sending_masked_input_;
99 }
100 
NumberOfPendingClients() const101 int SecAggServerR3UnmaskingState::NumberOfPendingClients() const {
102   return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
103 }
104 
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)105 void SecAggServerR3UnmaskingState::HandleAbortClient(
106     uint32_t client_id, ClientDropReason reason_code) {
107   if (client_status(client_id) == ClientStatus::UNMASKING_RESPONSE_RECEIVED) {
108     set_client_status(client_id,
109                       ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED);
110     return;
111   }
112   if (reason_code == ClientDropReason::EARLY_SUCCESS) {
113     number_of_clients_terminated_without_unmasking_++;
114   } else {
115     number_of_clients_failed_after_sending_masked_input_++;
116   }
117   set_client_status(client_id,
118                     ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
119   if (NumberOfPendingClients() + number_of_messages_received_in_this_round_ <
120       minimum_number_of_clients_to_proceed()) {
121     needs_to_abort_ = true;
122   }
123 }
124 
ReadyForNextRound() const125 bool SecAggServerR3UnmaskingState::ReadyForNextRound() const {
126   return (number_of_messages_received_in_this_round_ >=
127           minimum_number_of_clients_to_proceed()) ||
128          (needs_to_abort_);
129 }
130 
131 StatusOr<std::unique_ptr<SecAggServerState> >
ProceedToNextRound()132 SecAggServerR3UnmaskingState::ProceedToNextRound() {
133   if (!ReadyForNextRound()) {
134     return FCP_STATUS(UNAVAILABLE);
135   }
136   if (needs_to_abort_) {
137     std::string error_string = "Too many clients aborted.";
138     ServerToClientWrapperMessage message;
139     message.mutable_abort()->set_diagnostic_info(error_string);
140     message.mutable_abort()->set_early_success(false);
141     SendBroadcast(message);
142 
143     return AbortState(error_string,
144                       SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
145   }
146 
147   // Abort all clients that haven't yet sent a message, but let them count it as
148   // a success.
149   for (int i = 0; i < total_number_of_clients(); ++i) {
150     if (client_status(i) != ClientStatus::UNMASKING_RESPONSE_RECEIVED) {
151       AbortClient(
152           i,
153           "Client did not send unmasking response but protocol completed "
154           "successfully.",
155           ClientDropReason::EARLY_SUCCESS);
156     }
157   }
158 
159   return {std::make_unique<SecAggServerPrngRunningState>(
160       ExitState(StateTransition::kSuccess),
161       number_of_clients_failed_after_sending_masked_input_,
162       number_of_clients_failed_before_sending_masked_input_,
163       number_of_clients_terminated_without_unmasking_)};
164 }
165 
166 }  // namespace secagg
167 }  // namespace fcp
168