1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_H_ 18 #define FCP_SECAGG_SERVER_SECAGG_SERVER_H_ 19 20 #include <functional> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "absl/container/flat_hash_set.h" 26 #include "absl/container/node_hash_set.h" 27 #include "fcp/base/monitoring.h" 28 #include "fcp/base/scheduler.h" 29 #include "fcp/secagg/server/experiments_interface.h" 30 #include "fcp/secagg/server/experiments_names.h" 31 #include "fcp/secagg/server/secagg_scheduler.h" 32 #include "fcp/secagg/server/secagg_server_enums.pb.h" 33 #include "fcp/secagg/server/secagg_server_messages.pb.h" 34 #include "fcp/secagg/server/secagg_server_metrics_listener.h" 35 #include "fcp/secagg/server/secagg_server_state.h" 36 #include "fcp/secagg/server/secret_sharing_graph.h" 37 #include "fcp/secagg/server/tracing_schema.h" 38 #include "fcp/secagg/shared/aes_prng_factory.h" 39 #include "fcp/secagg/shared/secagg_messages.pb.h" 40 #include "fcp/tracing/tracing_span.h" 41 42 namespace fcp { 43 namespace secagg { 44 45 // Represents a server for the Secure Aggregation protocol. Each instance of 46 // this class performs just *one* session of the protocol. 47 // 48 // To create a new instance, use the public constructor. Once constructed, the 49 // server is ready to receive messages from clients with the ReceiveMessage 50 // method. 51 // 52 // When enough messages have been received (i.e. when ReceiveMessage or 53 // ReadyForNextRound return true) or any time after that, proceed to the next 54 // round by calling ProceedToNextRound. 55 // 56 // After all client interaction is done, the server needs to do some 57 // multi-threaded computation using the supplied Scheduler. Call StartPrng to 58 // begin this computation. 59 // 60 // When the computation is complete, call Result to get the final result. 61 // 62 // This class is not thread-safe. 63 64 class SecAggServer { 65 public: 66 // Constructs a new instance of the Secure Aggregation server. 67 // 68 // minimum_number_of_clients_to_proceed is the threshold lower bound on the 69 // total number of clients expected to complete the protocol. If there are 70 // ever fewer than this many clients still alive in the protocol, the server 71 // will abort (causing all clients to abort as well). 72 // 73 // total_number_of_clients is the number of clients selected to be in the 74 // cohort for this instance of Secure Aggregation. 75 // 76 // input_vector_specs must contain one InputVectorSpecification for each input 77 // vector which the protocol will aggregate. 78 // 79 // sender is used by the server to send messages to clients. The server will 80 // consume this object, taking ownership of it. 81 // 82 // sender may be called on a different thread than the thread used to call 83 // into SecAggServer, specifically in the PrngRunning state. 84 // 85 // prng_factory is a pointer to an instance of a subclass of AesPrngFactory. 86 // If this client will be communicating with the (C++) version of SecAggClient 87 // in this package, then the server and all clients should use 88 // AesCtrPrngFactory. 89 // 90 // metrics will be called over the course of the protocol to record message 91 // sizes and events. If it is null, no metrics will be recorded. 92 // 93 // threat_model includes the assumed maximum adversarial, maximum dropout 94 // rate, and adversary class. 95 // 96 // 97 // The protocol successfully 98 // completes and returns a sum if and only if no more than 99 // floor(total_number_of_clients * threat_model.estimated_dropout_rate()) 100 // clients dropout before the end of the protocol execution. This ensure that 101 // at least ceil(total_number_of_clients 102 // *(1. - threat_model.estimated_dropout_rate() - 103 // threat_model.adversarial_client_rate)) values from honest clients are 104 // included in the final sum. 105 // The protocol allows to make that threshold larger by providing a larger 106 // value of minimum_number_of_clients_to_proceed, but 107 // never lower (if the provided minimum_number_of_clients_to_proceed is 108 // smaller than ceil(total_number_of_clients *(1. - 109 // threat_model.estimated_dropout_rate())), the protocol defaults to the 110 // latter value. 111 static StatusOr<std::unique_ptr<SecAggServer>> Create( 112 int minimum_number_of_clients_to_proceed, int total_number_of_clients, 113 const std::vector<InputVectorSpecification>& input_vector_specs, 114 SendToClientsInterface* sender, 115 std::unique_ptr<SecAggServerMetricsListener> metrics, 116 std::unique_ptr<SecAggScheduler> prng_runner, 117 std::unique_ptr<ExperimentsInterface> experiments, 118 const SecureAggregationRequirements& threat_model); 119 120 ////////////////////////////// PROTOCOL METHODS ////////////////////////////// 121 122 // Makes the server abort the protocol, sending a message to all still-alive 123 // clients that the protocol has been aborted. Most of the state will be 124 // erased except for some diagnostic information. A new instance of 125 // SecAggServer will be needed to restart the protocol. 126 // 127 // If a reason string is provided, it will be stored by the server and sent to 128 // the clients as diagnostic information. 129 // An optional outcome can be provided for diagnostic purposes to be recorded 130 // via SecAggServerMetricsListener. By default, EXTERNAL_REQUEST outcome is 131 // assumed. 132 // 133 // The status will be OK unless the protocol was already completed or aborted. 134 Status Abort(); 135 Status Abort(const std::string& reason, SecAggServerOutcome outcome); 136 137 // Abort the specified client for the given reason. 138 // 139 // If the server is in a terminal state, returns a FAILED_PRECONDITION status. 140 Status AbortClient(uint32_t client_id, ClientAbortReason reason); 141 142 // Proceeds to the next round, doing necessary computation and sending 143 // messages to clients as appropriate. 144 // 145 // If the server is not ready to proceed, this method will do nothing and 146 // return an UNAVAILABLE status. If the server is already in a terminal state, 147 // this method will do nothing and return a FAILED_PRECONDITION status. 148 // 149 // If the server is ready to proceed, but not all clients have yet sent in 150 // responses, any client that hasn't yet sent a response will be aborted (and 151 // a message informing them of this will be sent). 152 // 153 // After proceeding to the next round, the server is ready to receive more 154 // messages from clients in rounds 1, 2, and 3. In the PrngRunning round, it 155 // is instead ready to have StartPrng called. 156 // 157 // Returns OK as long as the server has actually executed the transition to 158 // the next state. 159 Status ProceedToNextRound(); 160 161 // Processes a message that has been received from a client with the given 162 // client_id. 163 // 164 // The boolean returned indicates whether the server is ready to proceed to 165 // the next round. This will be true when a number of clients equal to the 166 // minimum_number_of_clients_to_proceed threshold have sent in valid messages 167 // (and not subsequently aborted), including this one. 168 // 169 // If the message is invalid, the client who sent it will be aborted, and a 170 // message will be sent to them notifying them of the fact. A client may also 171 // send the server a message that it wishes to abort (in which case no further 172 // message to it is sent). This may cause a server that was previously ready 173 // for the next round to no longer be ready, or it may cause the server to 174 // abort if not enough clients remain alive. 175 // 176 // Returns a FAILED_PRECONDITION status if the server is in a terminal state 177 // or the PRNG_RUNNING state. 178 // 179 // Returns an ABORTED status to signify that the server has aborted after 180 // receiving this message. (This will cause all surviving clients to be 181 // notified as well.) 182 StatusOr<bool> ReceiveMessage( 183 uint32_t client_id, 184 std::unique_ptr<ClientToServerWrapperMessage> message); 185 // Sets up a callback to be invoked when any background asynchronous work 186 // has been done. The callback is guaranteed to invoked via the server's 187 // callback scheduler. 188 // 189 // Returns true if asynchronous processing is supported in the current 190 // server state and the callback has been setup successfully. Returns false 191 // if asynchronous processing isn't supported in the current server state or 192 // if no further asynchronous processing is possible. The callback argument 193 // is ignored in that case. 194 bool SetAsyncCallback(std::function<void()> async_callback); 195 196 /////////////////////////////// STATUS METHODS /////////////////////////////// 197 198 // Returns the set of clients that aborted the protocol. Can be used by the 199 // caller to close the relevant RPC connections or just start ignoring 200 // incoming messages from those clients for performance reasons. 201 absl::flat_hash_set<uint32_t> AbortedClientIds() const; 202 203 // Returns a string describing the reason that the protocol was aborted. 204 // If the protocol has not actually been aborted, returns an error Status 205 // with code PRECONDITION_FAILED. 206 StatusOr<std::string> ErrorMessage() const; 207 208 // Returns true if the protocol has been aborted, false else. 209 bool IsAborted() const; 210 211 // Returns true if the protocol has been successfully completed, false else. 212 // The Result method can be called exactly when this method returns true. 213 bool IsCompletedSuccessfully() const; 214 215 // Whether the set of inputs that will be included in the final aggregation 216 // has been fixed. 217 // 218 // If true, the value of NumberOfIncludedInputs will be fixed for the 219 // remainder of the protocol. 220 bool IsNumberOfIncludedInputsCommitted() const; 221 222 // Indicates the minimum number of valid messages needed to be able to 223 // successfully move to the next round. 224 // 225 // Note that this value is not guaranteed to be monotonically decreasing. 226 // Client failures can cause this value to increase. 227 // 228 // Calling this in a terminal state results in an error. 229 StatusOr<int> MinimumMessagesNeededForNextRound() const; 230 231 // Indicates the total number of clients that the server expects to receive 232 // a response from in this round (i.e. the ones that have not aborted). In 233 // the COMPLETED state, this returns the number of clients that survived to 234 // the final protocol message. 235 int NumberOfAliveClients() const; 236 237 // Number of clients that failed after submitting their masked input. These 238 // clients' inputs will be included in the aggregate value, even though 239 // these clients did not complete the protocol. 240 int NumberOfClientsFailedAfterSendingMaskedInput() const; 241 242 // Number of clients that failed before submitting their masked input. These 243 // clients' inputs won't be included in the aggregate value, even if the 244 // protocol succeeds. 245 int NumberOfClientsFailedBeforeSendingMaskedInput() const; 246 247 // Number of clients that submitted a masked value, but didn't report their 248 // unmasking values fast enough to have them used in the final unmasking 249 // process. These clients' inputs will be included in the aggregate value. 250 int NumberOfClientsTerminatedWithoutUnmasking() const; 251 252 // Returns the number of inputs that will appear in the final sum, if the 253 // protocol completes. 254 // 255 // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed 256 // for the remainder of the protocol. 257 // 258 // This will be 0 if the server is aborted. This will also be 0 if the 259 // server is in an early state, prior to receiving masked inputs. It is 260 // incremented only when the server receives a masked input from a client. 261 int NumberOfIncludedInputs() const; 262 263 // Returns the number of live clients that have not yet submitted the 264 // expected response for the current round. In terminal states, this will be 265 // 0. 266 int NumberOfPendingClients() const; 267 268 // Returns the number of clients that would still be alive if 269 // ProceedToNextRound were called immediately after. This value may be less 270 // than NumberOfMessagesReceivedInThisRound if a client fails after sending 271 // a message in this round. 272 // 273 // Note that this value is not guaranteed to be monotonically increasing, 274 // even within a round. Client failures can cause this value to decrease. 275 // 276 // Calling this in a terminal state results in an error. 277 StatusOr<int> NumberOfClientsReadyForNextRound() const; 278 279 // Returns the number of valid messages received by clients this round. 280 // Unlike NumberOfClientsReadyForNextRound, this number is monotonically 281 // increasing until ProceedToNextRound is called, or the server aborts. 282 // 283 // Calling this in a terminal state results in an error. 284 StatusOr<int> NumberOfMessagesReceivedInThisRound() const; 285 286 // Returns a boolean indicating if the server has received enough messages 287 // from clients (who have not subsequently aborted) to proceed to the next 288 // round. ProceedToNextRound will do nothing unless this returns true. 289 // 290 // Even after this method returns true, the server will remain in the 291 // current round until ProceedToNextRound is called. 292 // 293 // Calling this in a terminal state results in an error. 294 StatusOr<bool> ReadyForNextRound() const; 295 296 // Transfers ownership of the result of the protocol to the caller. Requires 297 // the server to be in a completed state; returns UNAVAILABLE otherwise. 298 // Can be called only once; any consequitive calls result in an error. 299 StatusOr<std::unique_ptr<SecAggVectorMap>> Result(); 300 301 // Returns the number of neighbors of each client. 302 int NumberOfNeighbors() const; 303 304 // Returns the minimum number of neighbors of a client that must not 305 // drop-out for that client's contribution to be included in the sum. This 306 // corresponds to the threshold in the shamir secret sharing of self and 307 // pairwise masks. 308 int MinimumSurvivingNeighborsForReconstruction() const; 309 310 // Returns a value uniquely describing the current state of the client's 311 // FSM. 312 SecAggServerStateKind State() const; 313 314 private: 315 // Constructs a new instance of the Secure Aggregation server. 316 explicit SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl); 317 318 // This causes the server to transition into a new state, and call the 319 // callback if one is provided. 320 void TransitionState(std::unique_ptr<SecAggServerState> new_state); 321 322 // Validates if the client_id is within the expected bounds. 323 Status ValidateClientId(uint32_t client_id) const; 324 325 // Returns an error if the server is in Aborted or Completed state. 326 Status ErrorIfAbortedOrCompleted() const; 327 328 // The internal state object, containing details about the server's current 329 // state. 330 std::unique_ptr<SecAggServerState> state_; 331 332 // Tracing span for this session of SecAggServer. This is bound to the 333 // lifetime of SecAggServer i.e. from the time the object is created till it 334 // is destroyed. 335 UnscopedTracingSpan<SecureAggServerSession> span_; 336 337 // Holds pointer to a tracing span corresponding to the current active 338 // SecAggServerState. 339 std::unique_ptr<UnscopedTracingSpan<SecureAggServerState>> state_span_; 340 }; 341 342 } // namespace secagg 343 } // namespace fcp 344 #endif // FCP_SECAGG_SERVER_SECAGG_SERVER_H_ 345