1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
27
28 namespace fcp {
29 namespace secagg {
30
SecAggServerR2MaskedInputCollState(std::unique_ptr<SecAggServerProtocolImpl> impl,int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking)31 SecAggServerR2MaskedInputCollState::SecAggServerR2MaskedInputCollState(
32 std::unique_ptr<SecAggServerProtocolImpl> impl,
33 int number_of_clients_failed_after_sending_masked_input,
34 int number_of_clients_failed_before_sending_masked_input,
35 int number_of_clients_terminated_without_unmasking)
36 : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
37 number_of_clients_failed_before_sending_masked_input,
38 number_of_clients_terminated_without_unmasking,
39 SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION,
40 std::move(impl)) {
41 accumulator_ = this->impl()->SetupMaskedInputCollection();
42 }
43
~SecAggServerR2MaskedInputCollState()44 SecAggServerR2MaskedInputCollState::~SecAggServerR2MaskedInputCollState() {}
45
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)46 Status SecAggServerR2MaskedInputCollState::HandleMessage(
47 uint32_t client_id, const ClientToServerWrapperMessage& message) {
48 return FCP_STATUS(FAILED_PRECONDITION)
49 << "Call to deprecated HandleMessage method.";
50 }
51
HandleMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)52 Status SecAggServerR2MaskedInputCollState::HandleMessage(
53 uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
54 if (message->has_abort()) {
55 MessageReceived(*message, false);
56 AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
57 /*notify=*/false);
58 return FCP_STATUS(OK);
59 }
60 // If the client has aborted already, ignore its messages.
61 if (client_status(client_id) != ClientStatus::SHARE_KEYS_RECEIVED) {
62 MessageReceived(*message, false);
63 AbortClient(client_id,
64 "Not expecting an MaskedInputCollectionResponse from this "
65 "client - either the client already aborted or one such "
66 "message was already received.",
67 ClientDropReason::MASKED_INPUT_UNEXPECTED);
68 return FCP_STATUS(OK);
69 }
70 if (!message->has_masked_input_response()) {
71 MessageReceived(*message, false);
72 AbortClient(client_id,
73 "Message type received is different from what was expected.",
74 ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
75 return FCP_STATUS(OK);
76 }
77 MessageReceived(*message, true);
78
79 Status check_and_accumulate_status =
80 impl()->HandleMaskedInputCollectionResponse(
81 std::make_unique<MaskedInputCollectionResponse>(
82 std::move(*message->mutable_masked_input_response())));
83 if (!check_and_accumulate_status.ok()) {
84 AbortClient(client_id, std::string(check_and_accumulate_status.message()),
85 ClientDropReason::INVALID_MASKED_INPUT);
86 return FCP_STATUS(OK);
87 }
88 set_client_status(client_id, ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED);
89 number_of_messages_received_in_this_round_++;
90 number_of_clients_ready_for_next_round_++;
91 return FCP_STATUS(OK);
92 }
93
IsNumberOfIncludedInputsCommitted() const94 bool SecAggServerR2MaskedInputCollState::IsNumberOfIncludedInputsCommitted()
95 const {
96 return false;
97 }
98
MinimumMessagesNeededForNextRound() const99 int SecAggServerR2MaskedInputCollState::MinimumMessagesNeededForNextRound()
100 const {
101 return std::max(0, minimum_number_of_clients_to_proceed() -
102 number_of_clients_ready_for_next_round_);
103 }
104
NumberOfIncludedInputs() const105 int SecAggServerR2MaskedInputCollState::NumberOfIncludedInputs() const {
106 return number_of_messages_received_in_this_round_;
107 }
108
NumberOfPendingClients() const109 int SecAggServerR2MaskedInputCollState::NumberOfPendingClients() const {
110 return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
111 }
112
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)113 void SecAggServerR2MaskedInputCollState::HandleAbortClient(
114 uint32_t client_id, ClientDropReason reason_code) {
115 if (client_status(client_id) ==
116 ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
117 number_of_clients_ready_for_next_round_--;
118 number_of_clients_failed_after_sending_masked_input_++;
119 set_client_status(client_id,
120 ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
121 } else {
122 number_of_clients_failed_before_sending_masked_input_++;
123 clients_aborted_at_round_2_.push_back(client_id);
124 set_client_status(client_id, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
125 }
126 if (NumberOfAliveClients() < minimum_number_of_clients_to_proceed()) {
127 needs_to_abort_ = true;
128 }
129 }
130
HandleAbort()131 void SecAggServerR2MaskedInputCollState::HandleAbort() {
132 if (accumulator_) {
133 accumulator_->Cancel();
134 }
135 }
136
137 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()138 SecAggServerR2MaskedInputCollState::ProceedToNextRound() {
139 if (!ReadyForNextRound()) {
140 return FCP_STATUS(UNAVAILABLE);
141 }
142 if (needs_to_abort_) {
143 std::string error_string = "Too many clients aborted.";
144 ServerToClientWrapperMessage message;
145 message.mutable_abort()->set_diagnostic_info(error_string);
146 message.mutable_abort()->set_early_success(false);
147 SendBroadcast(message);
148 HandleAbort();
149
150 return AbortState(error_string,
151 SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
152 }
153
154 // Close all clients that haven't yet sent a message.
155 for (int i = 0; i < total_number_of_clients(); ++i) {
156 if (!IsClientDead(i) &&
157 client_status(i) != ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
158 AbortClient(i,
159 "Client did not send MaskedInputCollectionResponse before "
160 "round transition.",
161 ClientDropReason::NO_MASKED_INPUT);
162 }
163 }
164 // Send to each alive client the list of their aborted neighbors
165 for (int i = 0; i < total_number_of_clients(); ++i) {
166 if (IsClientDead(i)) {
167 continue;
168 }
169 ServerToClientWrapperMessage message_to_i;
170 // Set message to proper type
171 auto request = message_to_i.mutable_unmasking_request();
172 for (uint32_t aborted_client : clients_aborted_at_round_2_) {
173 // neighbor_index has a value iff i and aborted_client are neighbors
174 auto neighbor_index = GetNeighborIndex(i, aborted_client);
175 if (neighbor_index.has_value()) {
176 // TODO(team): Stop adding + 1 here once we don't need
177 // compatibility.
178 request->add_dead_3_client_ids(neighbor_index.value() + 1);
179 }
180 }
181 Send(i, message_to_i);
182 }
183
184 impl()->FinalizeMaskedInputCollection();
185
186 return {std::make_unique<SecAggServerR3UnmaskingState>(
187 ExitState(StateTransition::kSuccess),
188 number_of_clients_failed_after_sending_masked_input_,
189 number_of_clients_failed_before_sending_masked_input_,
190 number_of_clients_terminated_without_unmasking_)};
191 }
192
SetAsyncCallback(std::function<void ()> async_callback)193 bool SecAggServerR2MaskedInputCollState::SetAsyncCallback(
194 std::function<void()> async_callback) {
195 if (accumulator_) {
196 return accumulator_->SetAsyncObserver(async_callback);
197 }
198 return false;
199 }
200
ReadyForNextRound() const201 bool SecAggServerR2MaskedInputCollState::ReadyForNextRound() const {
202 // Accumulator is not set (this is a synchronous session) or it does not have
203 // unobserved work.
204 bool accumulator_is_idle = (!accumulator_ || accumulator_->IsIdle());
205 return accumulator_is_idle && ((number_of_clients_ready_for_next_round_ >=
206 minimum_number_of_clients_to_proceed()) ||
207 (needs_to_abort_));
208 }
209
210 } // namespace secagg
211 } // namespace fcp
212