1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
27 
28 namespace fcp {
29 namespace secagg {
30 
SecAggServerR2MaskedInputCollState(std::unique_ptr<SecAggServerProtocolImpl> impl,int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking)31 SecAggServerR2MaskedInputCollState::SecAggServerR2MaskedInputCollState(
32     std::unique_ptr<SecAggServerProtocolImpl> impl,
33     int number_of_clients_failed_after_sending_masked_input,
34     int number_of_clients_failed_before_sending_masked_input,
35     int number_of_clients_terminated_without_unmasking)
36     : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
37                         number_of_clients_failed_before_sending_masked_input,
38                         number_of_clients_terminated_without_unmasking,
39                         SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION,
40                         std::move(impl)) {
41   accumulator_ = this->impl()->SetupMaskedInputCollection();
42 }
43 
~SecAggServerR2MaskedInputCollState()44 SecAggServerR2MaskedInputCollState::~SecAggServerR2MaskedInputCollState() {}
45 
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)46 Status SecAggServerR2MaskedInputCollState::HandleMessage(
47     uint32_t client_id, const ClientToServerWrapperMessage& message) {
48   return FCP_STATUS(FAILED_PRECONDITION)
49          << "Call to deprecated HandleMessage method.";
50 }
51 
HandleMessage(uint32_t client_id,std::unique_ptr<ClientToServerWrapperMessage> message)52 Status SecAggServerR2MaskedInputCollState::HandleMessage(
53     uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
54   if (message->has_abort()) {
55     MessageReceived(*message, false);
56     AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
57                 /*notify=*/false);
58     return FCP_STATUS(OK);
59   }
60   // If the client has aborted already, ignore its messages.
61   if (client_status(client_id) != ClientStatus::SHARE_KEYS_RECEIVED) {
62     MessageReceived(*message, false);
63     AbortClient(client_id,
64                 "Not expecting an MaskedInputCollectionResponse from this "
65                 "client - either the client already aborted or one such "
66                 "message was already received.",
67                 ClientDropReason::MASKED_INPUT_UNEXPECTED);
68     return FCP_STATUS(OK);
69   }
70   if (!message->has_masked_input_response()) {
71     MessageReceived(*message, false);
72     AbortClient(client_id,
73                 "Message type received is different from what was expected.",
74                 ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
75     return FCP_STATUS(OK);
76   }
77   MessageReceived(*message, true);
78 
79   Status check_and_accumulate_status =
80       impl()->HandleMaskedInputCollectionResponse(
81           std::make_unique<MaskedInputCollectionResponse>(
82               std::move(*message->mutable_masked_input_response())));
83   if (!check_and_accumulate_status.ok()) {
84     AbortClient(client_id, std::string(check_and_accumulate_status.message()),
85                 ClientDropReason::INVALID_MASKED_INPUT);
86     return FCP_STATUS(OK);
87   }
88   set_client_status(client_id, ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED);
89   number_of_messages_received_in_this_round_++;
90   number_of_clients_ready_for_next_round_++;
91   return FCP_STATUS(OK);
92 }
93 
IsNumberOfIncludedInputsCommitted() const94 bool SecAggServerR2MaskedInputCollState::IsNumberOfIncludedInputsCommitted()
95     const {
96   return false;
97 }
98 
MinimumMessagesNeededForNextRound() const99 int SecAggServerR2MaskedInputCollState::MinimumMessagesNeededForNextRound()
100     const {
101   return std::max(0, minimum_number_of_clients_to_proceed() -
102                          number_of_clients_ready_for_next_round_);
103 }
104 
NumberOfIncludedInputs() const105 int SecAggServerR2MaskedInputCollState::NumberOfIncludedInputs() const {
106   return number_of_messages_received_in_this_round_;
107 }
108 
NumberOfPendingClients() const109 int SecAggServerR2MaskedInputCollState::NumberOfPendingClients() const {
110   return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
111 }
112 
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)113 void SecAggServerR2MaskedInputCollState::HandleAbortClient(
114     uint32_t client_id, ClientDropReason reason_code) {
115   if (client_status(client_id) ==
116       ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
117     number_of_clients_ready_for_next_round_--;
118     number_of_clients_failed_after_sending_masked_input_++;
119     set_client_status(client_id,
120                       ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
121   } else {
122     number_of_clients_failed_before_sending_masked_input_++;
123     clients_aborted_at_round_2_.push_back(client_id);
124     set_client_status(client_id, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
125   }
126   if (NumberOfAliveClients() < minimum_number_of_clients_to_proceed()) {
127     needs_to_abort_ = true;
128   }
129 }
130 
HandleAbort()131 void SecAggServerR2MaskedInputCollState::HandleAbort() {
132   if (accumulator_) {
133     accumulator_->Cancel();
134   }
135 }
136 
137 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()138 SecAggServerR2MaskedInputCollState::ProceedToNextRound() {
139   if (!ReadyForNextRound()) {
140     return FCP_STATUS(UNAVAILABLE);
141   }
142   if (needs_to_abort_) {
143     std::string error_string = "Too many clients aborted.";
144     ServerToClientWrapperMessage message;
145     message.mutable_abort()->set_diagnostic_info(error_string);
146     message.mutable_abort()->set_early_success(false);
147     SendBroadcast(message);
148     HandleAbort();
149 
150     return AbortState(error_string,
151                       SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
152   }
153 
154   // Close all clients that haven't yet sent a message.
155   for (int i = 0; i < total_number_of_clients(); ++i) {
156     if (!IsClientDead(i) &&
157         client_status(i) != ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
158       AbortClient(i,
159                   "Client did not send MaskedInputCollectionResponse before "
160                   "round transition.",
161                   ClientDropReason::NO_MASKED_INPUT);
162     }
163   }
164   // Send to each alive client the list of their aborted neighbors
165   for (int i = 0; i < total_number_of_clients(); ++i) {
166     if (IsClientDead(i)) {
167       continue;
168     }
169     ServerToClientWrapperMessage message_to_i;
170     // Set message to proper type
171     auto request = message_to_i.mutable_unmasking_request();
172     for (uint32_t aborted_client : clients_aborted_at_round_2_) {
173       //  neighbor_index has a value iff i and aborted_client are neighbors
174       auto neighbor_index = GetNeighborIndex(i, aborted_client);
175       if (neighbor_index.has_value()) {
176         // TODO(team): Stop adding + 1 here once we don't need
177         // compatibility.
178         request->add_dead_3_client_ids(neighbor_index.value() + 1);
179       }
180     }
181     Send(i, message_to_i);
182   }
183 
184   impl()->FinalizeMaskedInputCollection();
185 
186   return {std::make_unique<SecAggServerR3UnmaskingState>(
187       ExitState(StateTransition::kSuccess),
188       number_of_clients_failed_after_sending_masked_input_,
189       number_of_clients_failed_before_sending_masked_input_,
190       number_of_clients_terminated_without_unmasking_)};
191 }
192 
SetAsyncCallback(std::function<void ()> async_callback)193 bool SecAggServerR2MaskedInputCollState::SetAsyncCallback(
194     std::function<void()> async_callback) {
195   if (accumulator_) {
196     return accumulator_->SetAsyncObserver(async_callback);
197   }
198   return false;
199 }
200 
ReadyForNextRound() const201 bool SecAggServerR2MaskedInputCollState::ReadyForNextRound() const {
202   // Accumulator is not set (this is a synchronous session) or it does not have
203   // unobserved work.
204   bool accumulator_is_idle = (!accumulator_ || accumulator_->IsIdle());
205   return accumulator_is_idle && ((number_of_clients_ready_for_next_round_ >=
206                                   minimum_number_of_clients_to_proceed()) ||
207                                  (needs_to_abort_));
208 }
209 
210 }  // namespace secagg
211 }  // namespace fcp
212