xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_H_
18 #define FCP_SECAGG_SERVER_SECAGG_SERVER_H_
19 
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/container/node_hash_set.h"
27 #include "fcp/base/monitoring.h"
28 #include "fcp/base/scheduler.h"
29 #include "fcp/secagg/server/experiments_interface.h"
30 #include "fcp/secagg/server/experiments_names.h"
31 #include "fcp/secagg/server/secagg_scheduler.h"
32 #include "fcp/secagg/server/secagg_server_enums.pb.h"
33 #include "fcp/secagg/server/secagg_server_messages.pb.h"
34 #include "fcp/secagg/server/secagg_server_metrics_listener.h"
35 #include "fcp/secagg/server/secagg_server_state.h"
36 #include "fcp/secagg/server/secret_sharing_graph.h"
37 #include "fcp/secagg/server/tracing_schema.h"
38 #include "fcp/secagg/shared/aes_prng_factory.h"
39 #include "fcp/secagg/shared/secagg_messages.pb.h"
40 #include "fcp/tracing/tracing_span.h"
41 
42 namespace fcp {
43 namespace secagg {
44 
45 // Represents a server for the Secure Aggregation protocol. Each instance of
46 // this class performs just *one* session of the protocol.
47 //
48 // To create a new instance, use the public constructor. Once constructed, the
49 // server is ready to receive messages from clients with the ReceiveMessage
50 // method.
51 //
52 // When enough messages have been received (i.e. when ReceiveMessage or
53 // ReadyForNextRound return true) or any time after that, proceed to the next
54 // round by calling ProceedToNextRound.
55 //
56 // After all client interaction is done, the server needs to do some
57 // multi-threaded computation using the supplied Scheduler. Call StartPrng to
58 // begin this computation.
59 //
60 // When the computation is complete, call Result to get the final result.
61 //
62 // This class is not thread-safe.
63 
64 class SecAggServer {
65  public:
66   // Constructs a new instance of the Secure Aggregation server.
67   //
68   // minimum_number_of_clients_to_proceed is the threshold lower bound on the
69   // total number of clients expected to complete the protocol. If there are
70   // ever fewer than this many clients still alive in the protocol, the server
71   // will abort (causing all clients to abort as well).
72   //
73   // total_number_of_clients is the number of clients selected to be in the
74   // cohort for this instance of Secure Aggregation.
75   //
76   // input_vector_specs must contain one InputVectorSpecification for each input
77   // vector which the protocol will aggregate.
78   //
79   // sender is used by the server to send messages to clients. The server will
80   // consume this object, taking ownership of it.
81   //
82   // sender may be called on a different thread than the thread used to call
83   // into SecAggServer, specifically in the PrngRunning state.
84   //
85   // prng_factory is a pointer to an instance of a subclass of AesPrngFactory.
86   // If this client will be communicating with the (C++) version of SecAggClient
87   // in this package, then the server and all clients should use
88   // AesCtrPrngFactory.
89   //
90   // metrics will be called over the course of the protocol to record message
91   // sizes and events. If it is null, no metrics will be recorded.
92   //
93   // threat_model includes the assumed maximum adversarial, maximum dropout
94   // rate, and adversary class.
95   //
96   //
97   // The protocol successfully
98   // completes and returns a sum if and only if no more than
99   // floor(total_number_of_clients * threat_model.estimated_dropout_rate())
100   // clients dropout before the end of the protocol execution. This ensure that
101   // at least ceil(total_number_of_clients
102   // *(1. - threat_model.estimated_dropout_rate() -
103   // threat_model.adversarial_client_rate)) values from honest clients are
104   // included in the final sum.
105   // The protocol allows to make that threshold larger by providing a larger
106   // value of minimum_number_of_clients_to_proceed, but
107   // never lower (if the provided minimum_number_of_clients_to_proceed is
108   // smaller than ceil(total_number_of_clients *(1. -
109   // threat_model.estimated_dropout_rate())), the protocol defaults to the
110   // latter value.
111   static StatusOr<std::unique_ptr<SecAggServer>> Create(
112       int minimum_number_of_clients_to_proceed, int total_number_of_clients,
113       const std::vector<InputVectorSpecification>& input_vector_specs,
114       SendToClientsInterface* sender,
115       std::unique_ptr<SecAggServerMetricsListener> metrics,
116       std::unique_ptr<SecAggScheduler> prng_runner,
117       std::unique_ptr<ExperimentsInterface> experiments,
118       const SecureAggregationRequirements& threat_model);
119 
120   ////////////////////////////// PROTOCOL METHODS //////////////////////////////
121 
122   // Makes the server abort the protocol, sending a message to all still-alive
123   // clients that the protocol has been aborted. Most of the state will be
124   // erased except for some diagnostic information. A new instance of
125   // SecAggServer will be needed to restart the protocol.
126   //
127   // If a reason string is provided, it will be stored by the server and sent to
128   // the clients as diagnostic information.
129   // An optional outcome can be provided for diagnostic purposes to be recorded
130   // via SecAggServerMetricsListener. By default, EXTERNAL_REQUEST outcome is
131   // assumed.
132   //
133   // The status will be OK unless the protocol was already completed or aborted.
134   Status Abort();
135   Status Abort(const std::string& reason, SecAggServerOutcome outcome);
136 
137   // Abort the specified client for the given reason.
138   //
139   // If the server is in a terminal state, returns a FAILED_PRECONDITION status.
140   Status AbortClient(uint32_t client_id, ClientAbortReason reason);
141 
142   // Proceeds to the next round, doing necessary computation and sending
143   // messages to clients as appropriate.
144   //
145   // If the server is not ready to proceed, this method will do nothing and
146   // return an UNAVAILABLE status. If the server is already in a terminal state,
147   // this method will do nothing and return a FAILED_PRECONDITION status.
148   //
149   // If the server is ready to proceed, but not all clients have yet sent in
150   // responses, any client that hasn't yet sent a response will be aborted (and
151   // a message informing them of this will be sent).
152   //
153   // After proceeding to the next round, the server is ready to receive more
154   // messages from clients in rounds 1, 2, and 3. In the PrngRunning round, it
155   // is instead ready to have StartPrng called.
156   //
157   // Returns OK as long as the server has actually executed the transition to
158   // the next state.
159   Status ProceedToNextRound();
160 
161   // Processes a message that has been received from a client with the given
162   // client_id.
163   //
164   // The boolean returned indicates whether the server is ready to proceed to
165   // the next round. This will be true when a number of clients equal to the
166   // minimum_number_of_clients_to_proceed threshold have sent in valid messages
167   // (and not subsequently aborted), including this one.
168   //
169   // If the message is invalid, the client who sent it will be aborted, and a
170   // message will be sent to them notifying them of the fact. A client may also
171   // send the server a message that it wishes to abort (in which case no further
172   // message to it is sent). This may cause a server that was previously ready
173   // for the next round to no longer be ready, or it may cause the server to
174   // abort if not enough clients remain alive.
175   //
176   // Returns a FAILED_PRECONDITION status if the server is in a terminal state
177   // or the PRNG_RUNNING state.
178   //
179   // Returns an ABORTED status to signify that the server has aborted after
180   // receiving this message. (This will cause all surviving clients to be
181   // notified as well.)
182   StatusOr<bool> ReceiveMessage(
183       uint32_t client_id,
184       std::unique_ptr<ClientToServerWrapperMessage> message);
185   // Sets up a callback to be invoked when any background asynchronous work
186   // has been done. The callback is guaranteed to invoked via the server's
187   // callback scheduler.
188   //
189   // Returns true if asynchronous processing is supported in the current
190   // server state and the callback has been setup successfully. Returns false
191   // if asynchronous processing isn't supported in the current server state or
192   // if no further asynchronous processing is possible. The callback argument
193   // is ignored in that case.
194   bool SetAsyncCallback(std::function<void()> async_callback);
195 
196   /////////////////////////////// STATUS METHODS ///////////////////////////////
197 
198   // Returns the set of clients that aborted the protocol. Can be used by the
199   // caller to close the relevant RPC connections or just start ignoring
200   // incoming messages from those clients for performance reasons.
201   absl::flat_hash_set<uint32_t> AbortedClientIds() const;
202 
203   // Returns a string describing the reason that the protocol was aborted.
204   // If the protocol has not actually been aborted, returns an error Status
205   // with code PRECONDITION_FAILED.
206   StatusOr<std::string> ErrorMessage() const;
207 
208   // Returns true if the protocol has been aborted, false else.
209   bool IsAborted() const;
210 
211   // Returns true if the protocol has been successfully completed, false else.
212   // The Result method can be called exactly when this method returns true.
213   bool IsCompletedSuccessfully() const;
214 
215   // Whether the set of inputs that will be included in the final aggregation
216   // has been fixed.
217   //
218   // If true, the value of NumberOfIncludedInputs will be fixed for the
219   // remainder of the protocol.
220   bool IsNumberOfIncludedInputsCommitted() const;
221 
222   // Indicates the minimum number of valid messages needed to be able to
223   // successfully move to the next round.
224   //
225   // Note that this value is not guaranteed to be monotonically decreasing.
226   // Client failures can cause this value to increase.
227   //
228   // Calling this in a terminal state results in an error.
229   StatusOr<int> MinimumMessagesNeededForNextRound() const;
230 
231   // Indicates the total number of clients that the server expects to receive
232   // a response from in this round (i.e. the ones that have not aborted). In
233   // the COMPLETED state, this returns the number of clients that survived to
234   // the final protocol message.
235   int NumberOfAliveClients() const;
236 
237   // Number of clients that failed after submitting their masked input. These
238   // clients' inputs will be included in the aggregate value, even though
239   // these clients did not complete the protocol.
240   int NumberOfClientsFailedAfterSendingMaskedInput() const;
241 
242   // Number of clients that failed before submitting their masked input. These
243   // clients' inputs won't be included in the aggregate value, even if the
244   // protocol succeeds.
245   int NumberOfClientsFailedBeforeSendingMaskedInput() const;
246 
247   // Number of clients that submitted a masked value, but didn't report their
248   // unmasking values fast enough to have them used in the final unmasking
249   // process. These clients' inputs will be included in the aggregate value.
250   int NumberOfClientsTerminatedWithoutUnmasking() const;
251 
252   // Returns the number of inputs that will appear in the final sum, if the
253   // protocol completes.
254   //
255   // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed
256   // for the remainder of the protocol.
257   //
258   // This will be 0 if the server is aborted. This will also be 0 if the
259   // server is in an early state, prior to receiving masked inputs. It is
260   // incremented only when the server receives a masked input from a client.
261   int NumberOfIncludedInputs() const;
262 
263   // Returns the number of live clients that have not yet submitted the
264   // expected response for the current round. In terminal states, this will be
265   // 0.
266   int NumberOfPendingClients() const;
267 
268   // Returns the number of clients that would still be alive if
269   // ProceedToNextRound were called immediately after. This value may be less
270   // than NumberOfMessagesReceivedInThisRound if a client fails after sending
271   // a message in this round.
272   //
273   // Note that this value is not guaranteed to be monotonically increasing,
274   // even within a round. Client failures can cause this value to decrease.
275   //
276   // Calling this in a terminal state results in an error.
277   StatusOr<int> NumberOfClientsReadyForNextRound() const;
278 
279   // Returns the number of valid messages received by clients this round.
280   // Unlike NumberOfClientsReadyForNextRound, this number is monotonically
281   // increasing until ProceedToNextRound is called, or the server aborts.
282   //
283   // Calling this in a terminal state results in an error.
284   StatusOr<int> NumberOfMessagesReceivedInThisRound() const;
285 
286   // Returns a boolean indicating if the server has received enough messages
287   // from clients (who have not subsequently aborted) to proceed to the next
288   // round. ProceedToNextRound will do nothing unless this returns true.
289   //
290   // Even after this method returns true, the server will remain in the
291   // current round until ProceedToNextRound is called.
292   //
293   // Calling this in a terminal state results in an error.
294   StatusOr<bool> ReadyForNextRound() const;
295 
296   // Transfers ownership of the result of the protocol to the caller. Requires
297   // the server to be in a completed state; returns UNAVAILABLE otherwise.
298   // Can be called only once; any consequitive calls result in an error.
299   StatusOr<std::unique_ptr<SecAggVectorMap>> Result();
300 
301   // Returns the number of neighbors of each client.
302   int NumberOfNeighbors() const;
303 
304   // Returns the minimum number of neighbors of a client that must not
305   // drop-out for that client's contribution to be included in the sum. This
306   // corresponds to the threshold in the shamir secret sharing of self and
307   // pairwise masks.
308   int MinimumSurvivingNeighborsForReconstruction() const;
309 
310   // Returns a value uniquely describing the current state of the client's
311   // FSM.
312   SecAggServerStateKind State() const;
313 
314  private:
315   // Constructs a new instance of the Secure Aggregation server.
316   explicit SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl);
317 
318   // This causes the server to transition into a new state, and call the
319   // callback if one is provided.
320   void TransitionState(std::unique_ptr<SecAggServerState> new_state);
321 
322   // Validates if the client_id is within the expected bounds.
323   Status ValidateClientId(uint32_t client_id) const;
324 
325   // Returns an error if the server is in Aborted or Completed state.
326   Status ErrorIfAbortedOrCompleted() const;
327 
328   // The internal state object, containing details about the server's current
329   // state.
330   std::unique_ptr<SecAggServerState> state_;
331 
332   // Tracing span for this session of SecAggServer. This is bound to the
333   // lifetime of SecAggServer i.e. from the time the object is created till it
334   // is destroyed.
335   UnscopedTracingSpan<SecureAggServerSession> span_;
336 
337   // Holds pointer to a tracing span corresponding to the current active
338   // SecAggServerState.
339   std::unique_ptr<UnscopedTracingSpan<SecureAggServerState>> state_span_;
340 };
341 
342 }  // namespace secagg
343 }  // namespace fcp
344 #endif  // FCP_SECAGG_SERVER_SECAGG_SERVER_H_
345