xref: /aosp_15_r20/external/federated-compute/fcp/secagg/client/secagg_client_r3_unmasking_state.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "fcp/base/monitoring.h"
25 #include "fcp/secagg/client/other_client_state.h"
26 #include "fcp/secagg/client/secagg_client_aborted_state.h"
27 #include "fcp/secagg/client/secagg_client_alive_base_state.h"
28 #include "fcp/secagg/client/secagg_client_completed_state.h"
29 #include "fcp/secagg/client/secagg_client_state.h"
30 #include "fcp/secagg/client/send_to_server_interface.h"
31 #include "fcp/secagg/client/state_transition_listener_interface.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/shared/secagg_vector.h"
34 
35 namespace fcp {
36 namespace secagg {
37 
SecAggClientR3UnmaskingState(uint32_t client_id,uint32_t number_of_alive_neighbors,uint32_t minimum_surviving_neighbors_for_reconstruction,uint32_t number_of_neighbors,std::unique_ptr<std::vector<OtherClientState>> other_client_states,std::unique_ptr<std::vector<ShamirShare>> pairwise_key_shares,std::unique_ptr<std::vector<ShamirShare>> self_key_shares,std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,AsyncAbort * async_abort)38 SecAggClientR3UnmaskingState::SecAggClientR3UnmaskingState(
39     uint32_t client_id, uint32_t number_of_alive_neighbors,
40     uint32_t minimum_surviving_neighbors_for_reconstruction,
41     uint32_t number_of_neighbors,
42     std::unique_ptr<std::vector<OtherClientState> > other_client_states,
43     std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares,
44     std::unique_ptr<std::vector<ShamirShare> > self_key_shares,
45     std::unique_ptr<SendToServerInterface> sender,
46     std::unique_ptr<StateTransitionListenerInterface> transition_listener,
47     AsyncAbort* async_abort)
48     : SecAggClientAliveBaseState(std::move(sender),
49                                  std::move(transition_listener),
50                                  ClientState::R3_UNMASKING, async_abort),
51       client_id_(client_id),
52       number_of_alive_neighbors_(number_of_alive_neighbors),
53       minimum_surviving_neighbors_for_reconstruction_(
54           minimum_surviving_neighbors_for_reconstruction),
55       number_of_neighbors_(number_of_neighbors),
56       other_client_states_(std::move(other_client_states)),
57       pairwise_key_shares_(std::move(pairwise_key_shares)),
58       self_key_shares_(std::move(self_key_shares)) {
59   FCP_CHECK(client_id_ >= 0)
60       << "Client id must not be negative but was " << client_id_;
61 }
62 
63 SecAggClientR3UnmaskingState::~SecAggClientR3UnmaskingState() = default;
64 
65 StatusOr<std::unique_ptr<SecAggClientState> >
HandleMessage(const ServerToClientWrapperMessage & message)66 SecAggClientR3UnmaskingState::HandleMessage(
67     const ServerToClientWrapperMessage& message) {
68   // Handle abort messages or unmasking requests only.
69   if (message.has_abort()) {
70     if (message.abort().early_success()) {
71       return {std::make_unique<SecAggClientCompletedState>(
72           std::move(sender_), std::move(transition_listener_))};
73     } else {
74       return {std::make_unique<SecAggClientAbortedState>(
75           "Aborting because of abort message from the server.",
76           std::move(sender_), std::move(transition_listener_))};
77     }
78   } else if (!message.has_unmasking_request()) {
79     // Returns an error indicating that the message is of invalid type.
80     return SecAggClientState::HandleMessage(message);
81   }
82   if (async_abort_ && async_abort_->Signalled())
83     return AbortAndNotifyServer(async_abort_->Message());
84 
85   const UnmaskingRequest& request = message.unmasking_request();
86   std::set<uint32_t> dead_at_round_3_client_ids;
87 
88   // Parse incoming request and mark dead clients as dead.
89   for (uint32_t i : request.dead_3_client_ids()) {
90     // TODO(team): Remove this once backwards compatibility not needed.
91     uint32_t id = i - 1;
92     if (id == client_id_) {
93       return AbortAndNotifyServer(
94           "The received UnmaskingRequest states this client has aborted, but "
95           "this client had not yet aborted.");
96     } else if (id >= number_of_neighbors_) {
97       return AbortAndNotifyServer(
98           "The received UnmaskingRequest contains a client id that does not "
99           "correspond to any client.");
100     }
101     switch ((*other_client_states_)[id]) {
102       case OtherClientState::kAlive:
103         (*other_client_states_)[id] = OtherClientState::kDeadAtRound3;
104         --number_of_alive_neighbors_;
105         break;
106       case OtherClientState::kDeadAtRound3:
107         return AbortAndNotifyServer(
108             "The received UnmaskingRequest repeated a client more than once "
109             "as a dead client.");
110         break;
111       case OtherClientState::kDeadAtRound1:
112       case OtherClientState::kDeadAtRound2:
113       default:
114         return AbortAndNotifyServer(
115             "The received UnmaskingRequest considers a client dead in round 3 "
116             "that was already considered dead.");
117         break;
118     }
119   }
120 
121   if (number_of_alive_neighbors_ <
122       minimum_surviving_neighbors_for_reconstruction_) {
123     return AbortAndNotifyServer(
124         "Not enough clients survived. The server should not have sent this "
125         "UnmaskingRequest.");
126   }
127 
128   /*
129    * Construct a response for the server by choosing the appropriate shares for
130    * each client (i.e. the pairwise share if the client died at round 3, the
131    * self share if the client is alive, or no shares at all if the client died
132    * at or before round 2.
133    */
134   ClientToServerWrapperMessage message_to_server;
135   UnmaskingResponse* unmasking_response =
136       message_to_server.mutable_unmasking_response();
137   for (uint32_t i = 0; i < number_of_neighbors_; ++i) {
138     if (async_abort_ && async_abort_->Signalled())
139       return AbortAndNotifyServer(async_abort_->Message());
140     switch ((*other_client_states_)[i]) {
141       case OtherClientState::kAlive:
142         unmasking_response->add_noise_or_prf_key_shares()->set_prf_sk_share(
143             (*self_key_shares_)[i].data);
144         break;
145       case OtherClientState::kDeadAtRound3:
146         unmasking_response->add_noise_or_prf_key_shares()->set_noise_sk_share(
147             (*pairwise_key_shares_)[i].data);
148         break;
149       case OtherClientState::kDeadAtRound1:
150       case OtherClientState::kDeadAtRound2:
151       default:
152         unmasking_response->add_noise_or_prf_key_shares();
153         break;
154     }
155   }
156 
157   // Send this final message to the server, then enter Completed state.
158   sender_->Send(&message_to_server);
159   return {std::make_unique<SecAggClientCompletedState>(
160       std::move(sender_), std::move(transition_listener_))};
161 }
162 
StateName() const163 std::string SecAggClientR3UnmaskingState::StateName() const {
164   return "R3_UNMASKING";
165 }
166 
167 }  // namespace secagg
168 }  // namespace fcp
169