xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_state.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
18 #define FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
19 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/time/time.h"
#include "fcp/secagg/server/secagg_server_enums.pb.h"
#include "fcp/secagg/server/secagg_server_protocol_impl.h"
#include "fcp/secagg/server/tracing_schema.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
30 
31 namespace fcp {
32 namespace secagg {
33 
34 // This is an abstract class which is the parent of the other SecAggServer*State
35 // classes. It should not be instantiated directly. Default versions of all the
36 // methods declared here are provided for use by states which do not expect, and
37 // therefore do not implement, those methods.
38 
39 class SecAggServerState {
40  public:
41   // Returns the number of clients selected to be in the cohort for this
42   // instance of Secure Aggregation.
total_number_of_clients()43   inline const size_t total_number_of_clients() const {
44     return impl_->total_number_of_clients();
45   }
46 
47   // Returns the number of neighbors of each client.
number_of_neighbors()48   inline const int number_of_neighbors() const {
49     return impl_->number_of_neighbors();
50   }
51 
52   // Returns the minimum number of neighbors of a client that must not drop-out
53   // for that client's contribution to be included in the sum. This corresponds
54   // to the threshold in the shamir secret sharing of self and pairwise masks.
minimum_surviving_neighbors_for_reconstruction()55   inline const int minimum_surviving_neighbors_for_reconstruction() const {
56     return impl_->minimum_surviving_neighbors_for_reconstruction();
57   }
58 
59   // Returns the index of client_id_2 in the list of neighbors of client_id_1,
60   // if present
GetNeighborIndex(int client_id_1,int client_id_2)61   inline const std::optional<int> GetNeighborIndex(int client_id_1,
62                                                    int client_id_2) const {
63     return impl_->GetNeighborIndex(client_id_1, client_id_2);
64   }
65 
66   // EnterState must be called just after transitioning to a state.
67   // States may use this to initialize their state or trigger work.
EnterState()68   virtual void EnterState() {}
69 
70   // Processes the received message in a way consistent with the current state.
71   //
72   // Returns OK status to indicate that the message has been handled
73   // successfully.
74   //
75   // Returns a FAILED_PRECONDITION status if the server is in a state from which
76   // it does not expect to receive any messages. In that case no reply will be
77   // sent.
78   virtual Status HandleMessage(uint32_t client_id,
79                                const ClientToServerWrapperMessage& message);
80   // Analog of the above method, bu giving ownership of the message.
81   virtual Status HandleMessage(
82       uint32_t client_id,
83       std::unique_ptr<ClientToServerWrapperMessage> message);
84 
85   // Proceeds to the next round, doing all necessary computation and sending
86   // messages to clients as appropriate. If the server is not yet ready to
87   // proceed, returns an UNAVAILABLE status.
88   //
89   // If the server is in a terminal state, returns a FAILED_PRECONDITION status.
90   //
91   // Otherwise, returns the new state. This may be an abort state if the server
92   // has aborted.
93   //
94   // If this method returns a new state (i.e. if the status is OK), then the old
95   // state is no longer valid and the new state must be considered the current
96   // state. If it returns a non-OK status, this method does not change the
97   // underlying state.
98   virtual StatusOr<std::unique_ptr<SecAggServerState>> ProceedToNextRound();
99 
100   // Returns true if the client state is considered to be "dead" e.g. aborted or
101   // disconnected; otherwise returns false.
102   bool IsClientDead(uint32_t client_id) const;
103 
104   // Abort the specified client for the given reason. If notify is true, send a
105   // notification message to the client. (If the client was already closed, no
106   // message will be sent).
107   //
108   // The reason code will be used for recording metrics if log_metrics is true,
109   // else no metrics are recorded. By default, metrics will always be logged.
110   void AbortClient(uint32_t client_id, const std::string& reason,
111                    ClientDropReason reason_code, bool notify = true,
112                    bool log_metrics = true);
113 
114   // Aborts the protocol for the specified reason. Notifies all clients of
115   // the abort. Returns the new state.
116   // Calling this method on a terminal state isn't valid.
117   std::unique_ptr<SecAggServerState> Abort(const std::string& reason,
118                                            SecAggServerOutcome outcome);
119 
120   // Returns true if the current state is Abort, false else.
121   virtual bool IsAborted() const;
122 
123   // Returns true if the current state is ProtocolCompleted, false else.
124   virtual bool IsCompletedSuccessfully() const;
125 
126   // Returns an error message explaining why the server aborted, if the current
127   // state is an abort state. If not returns an error Status with code
128   // FAILED_PRECONDITION.
129   virtual StatusOr<std::string> ErrorMessage() const;
130 
131   // Returns an enum specifying the current state.
132   SecAggServerStateKind State() const;
133 
134   // Returns the name of the current state in the form of a short string.
135   std::string StateName() const;
136 
137   // Returns whether or not the server has received enough messages to be ready
138   // for the next phase of the protocol.
139   // In the PRNG Running state, it returns whether or not the PRNG has stopped
140   // running.
141   // Always false in a terminal state.
142   virtual bool ReadyForNextRound() const;
143 
144   // Returns the number of valid messages received by clients this round.
145   int NumberOfMessagesReceivedInThisRound() const;
146 
147   // Returns the number of clients that would still be alive if
148   // ProceedToNextRound were called immediately after. This value may be less
149   // than NumberOfMessagesReceivedInThisRound if a client fails after sending a
150   // message in this round.
151   // Note that this value is not guaranteed to be monotonically increasing, even
152   // within a round. Client failures can cause this value to decrease.
153   virtual int NumberOfClientsReadyForNextRound() const;
154 
155   // Indicates the total number of clients that the server expects to receive a
156   // response from in this round (i.e. the ones that have not aborted).
157   // In the COMPLETED state, this returns the number of clients that survived to
158   // the final protocol message.
159   virtual int NumberOfAliveClients() const;
160 
161   // Number of clients that failed before submitting their masked input. These
162   // clients' inputs won't be included in the aggregate value, even if the
163   // protocol succeeds.
164   int NumberOfClientsFailedBeforeSendingMaskedInput() const;
165 
166   // Number of clients that failed after submitting their masked input. These
167   // clients' inputs will be included in the aggregate value, even though these
168   // clients did not complete the protocol.
169   int NumberOfClientsFailedAfterSendingMaskedInput() const;
170 
171   // Number of clients that submitted a masked value, but didn't report their
172   // unmasking values fast enough to have them used in the final unmasking
173   // process. These clients' inputs will be included in the aggregate value.
174   int NumberOfClientsTerminatedWithoutUnmasking() const;
175 
176   // Returns the number of live clients that have not yet submitted the expected
177   // response for the current round. In terminal states, this will be 0.
178   virtual int NumberOfPendingClients() const;
179 
180   // Returns the number of inputs that will appear in the final sum, if the
181   // protocol completes.
182   // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed
183   // for the remainder of the protocol.
184   // This will be 0 if the server is aborted. This will also be 0 if the server
185   // is in an early state, prior to receiving masked inputs. It is incremented
186   // only when the server receives a masked input from a client.
187   virtual int NumberOfIncludedInputs() const;
188 
189   // Whether the set of inputs that will be included in the final aggregation is
190   // fixed.
191   // If true, the value of NumberOfIncludedInputs will be fixed for the
192   // remainder of the protocol.
193   virtual bool IsNumberOfIncludedInputsCommitted() const = 0;
194 
195   // Indicates the minimum number of valid messages needed to be able to
196   // successfully move to the next round.
197   // Note that this value is not guaranteed to be monotonically decreasing.
198   // Client failures can cause this value to increase.
199   // In terminal states, this returns 0.
200   virtual int MinimumMessagesNeededForNextRound() const;
201 
202   // Returns the minimum threshold number of clients that need to send valid
203   // responses in order for the protocol to proceed from one round to the next.
minimum_number_of_clients_to_proceed()204   inline const int minimum_number_of_clients_to_proceed() const {
205     return impl_->minimum_number_of_clients_to_proceed();
206   }
207 
208   // Returns the set of clients that aborted the protocol. Can be used by the
209   // caller to close the relevant RPC connections or just start ignoring
210   // incoming messages from those clients for performance reasons.
211   absl::flat_hash_set<uint32_t> AbortedClientIds() const;
212 
213   // Returns true if the server has determined that it needs to abort itself,
214   // If the server is in a terminal state, returns false.
215   bool NeedsToAbort() const;
216 
217   // Sets up a callback to be triggered when any background asynchronous work
218   // has been done. The callback is guaranteed to invoked via the server's
219   // callback scheduler.
220   //
221   // Returns true if the state supports asynchronous processing and the callback
222   // has been setup successfully.
223   // Returns false if the state doesn't support asynchronous processing or if
224   // no further asynchronous processing is possible. The callback argument is
225   // ignored in this case.
226   virtual bool SetAsyncCallback(std::function<void()> async_callback);
227 
228   // Transfers ownership of the result of the protocol to the caller. Requires
229   // the server to be in a completed state; returns UNAVAILABLE otherwise.
230   // Can be called only once; any consecutive calls result in an error.
231   virtual StatusOr<std::unique_ptr<SecAggVectorMap>> Result();
232 
233   virtual ~SecAggServerState();
234 
235  protected:
236   // SecAggServerState should never be instantiated directly.
237   SecAggServerState(int number_of_clients_failed_after_sending_masked_input,
238                     int number_of_clients_failed_before_sending_masked_input,
239                     int number_of_clients_terminated_without_unmasking,
240                     SecAggServerStateKind state_kind,
241                     std::unique_ptr<SecAggServerProtocolImpl> impl);
242 
impl()243   SecAggServerProtocolImpl* impl() { return impl_.get(); }
244 
245   // Returns the callback interface for recording metrics.
metrics()246   inline SecAggServerMetricsListener* metrics() const {
247     return impl_->metrics();
248   }
249 
250   // Returns the callback interface for sending protocol buffer messages to the
251   // client.
sender()252   inline SendToClientsInterface* sender() const { return impl_->sender(); }
253 
client_status(uint32_t client_id)254   inline const ClientStatus& client_status(uint32_t client_id) const {
255     return impl_->client_status(client_id);
256   }
257 
set_client_status(uint32_t client_id,ClientStatus status)258   inline void set_client_status(uint32_t client_id, ClientStatus status) {
259     impl_->set_client_status(client_id, status);
260   }
261 
262   // Records information about a message that was received from a client.
263   void MessageReceived(const ClientToServerWrapperMessage& message,
264                        bool expected);
265 
266   // Broadcasts the message and records metrics.
267   void SendBroadcast(const ServerToClientWrapperMessage& message);
268 
269   // Sends the message to the given client and records metrics.
270   void Send(uint32_t recipient_id, const ServerToClientWrapperMessage& message);
271 
272   // Returns an aborted version of the current state, storing the specified
273   // reason. Calling this method makes the current state unusable. The caller is
274   // responsible for sending any failure messages that need to be sent, and for
275   // doing so BEFORE calling this method.
276   // The SecAggServerOutcome outcome is used for recording metrics.
277   std::unique_ptr<SecAggServerState> AbortState(const std::string& reason,
278                                                 SecAggServerOutcome outcome);
279 
280   // ExitState must be called on the current state just before transitioning to
281   // a new state to record metrics and transfer out the shared state.
282   enum class StateTransition {
283     // Indicates a successful state transition to any state other than Aborted.
284     kSuccess = 0,
285     // Indicates transition to Aborted state.
286     kAbort = 1
287   };
288   std::unique_ptr<SecAggServerProtocolImpl>&& ExitState(
289       StateTransition state_transition_status);
290 
291   bool needs_to_abort_;
292   int number_of_clients_failed_after_sending_masked_input_;
293   int number_of_clients_failed_before_sending_masked_input_;
294   int number_of_clients_ready_for_next_round_;
295   int number_of_clients_terminated_without_unmasking_;
296   int number_of_messages_received_in_this_round_;
297   absl::Time round_start_;
298   SecAggServerStateKind state_kind_;
299 
300  private:
301   // Performs state specific action when a client is aborted.
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)302   virtual void HandleAbortClient(uint32_t client_id,
303                                  ClientDropReason reason_code) {}
304 
305   // Performs state specific action when the server is aborted.
HandleAbort()306   virtual void HandleAbort() {}
307 
308   std::unique_ptr<SecAggServerProtocolImpl> impl_;
309 };
310 
311 }  // namespace secagg
312 }  // namespace fcp
313 
314 #endif  // FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
315