1 /* 2 * Copyright 2018 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_ 18 #define FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_ 19 20 #include <functional> 21 #include <memory> 22 #include <string> 23 24 #include "absl/container/flat_hash_set.h" 25 #include "absl/time/time.h" 26 #include "fcp/secagg/server/secagg_server_enums.pb.h" 27 #include "fcp/secagg/server/secagg_server_protocol_impl.h" 28 #include "fcp/secagg/server/tracing_schema.h" 29 #include "fcp/secagg/shared/secagg_messages.pb.h" 30 31 namespace fcp { 32 namespace secagg { 33 34 // This is an abstract class which is the parent of the other SecAggServer*State 35 // classes. It should not be instantiated directly. Default versions of all the 36 // methods declared here are provided for use by states which do not expect, and 37 // therefore do not implement, those methods. 38 39 class SecAggServerState { 40 public: 41 // Returns the number of clients selected to be in the cohort for this 42 // instance of Secure Aggregation. total_number_of_clients()43 inline const size_t total_number_of_clients() const { 44 return impl_->total_number_of_clients(); 45 } 46 47 // Returns the number of neighbors of each client. number_of_neighbors()48 inline const int number_of_neighbors() const { 49 return impl_->number_of_neighbors(); 50 } 51 52 // Returns the minimum number of neighbors of a client that must not drop-out 53 // for that client's contribution to be included in the sum. This corresponds 54 // to the threshold in the shamir secret sharing of self and pairwise masks. minimum_surviving_neighbors_for_reconstruction()55 inline const int minimum_surviving_neighbors_for_reconstruction() const { 56 return impl_->minimum_surviving_neighbors_for_reconstruction(); 57 } 58 59 // Returns the index of client_id_2 in the list of neighbors of client_id_1, 60 // if present GetNeighborIndex(int client_id_1,int client_id_2)61 inline const std::optional<int> GetNeighborIndex(int client_id_1, 62 int client_id_2) const { 63 return impl_->GetNeighborIndex(client_id_1, client_id_2); 64 } 65 66 // EnterState must be called just after transitioning to a state. 67 // States may use this to initialize their state or trigger work. EnterState()68 virtual void EnterState() {} 69 70 // Processes the received message in a way consistent with the current state. 71 // 72 // Returns OK status to indicate that the message has been handled 73 // successfully. 74 // 75 // Returns a FAILED_PRECONDITION status if the server is in a state from which 76 // it does not expect to receive any messages. In that case no reply will be 77 // sent. 78 virtual Status HandleMessage(uint32_t client_id, 79 const ClientToServerWrapperMessage& message); 80 // Analog of the above method, bu giving ownership of the message. 81 virtual Status HandleMessage( 82 uint32_t client_id, 83 std::unique_ptr<ClientToServerWrapperMessage> message); 84 85 // Proceeds to the next round, doing all necessary computation and sending 86 // messages to clients as appropriate. If the server is not yet ready to 87 // proceed, returns an UNAVAILABLE status. 88 // 89 // If the server is in a terminal state, returns a FAILED_PRECONDITION status. 90 // 91 // Otherwise, returns the new state. This may be an abort state if the server 92 // has aborted. 93 // 94 // If this method returns a new state (i.e. if the status is OK), then the old 95 // state is no longer valid and the new state must be considered the current 96 // state. If it returns a non-OK status, this method does not change the 97 // underlying state. 98 virtual StatusOr<std::unique_ptr<SecAggServerState>> ProceedToNextRound(); 99 100 // Returns true if the client state is considered to be "dead" e.g. aborted or 101 // disconnected; otherwise returns false. 102 bool IsClientDead(uint32_t client_id) const; 103 104 // Abort the specified client for the given reason. If notify is true, send a 105 // notification message to the client. (If the client was already closed, no 106 // message will be sent). 107 // 108 // The reason code will be used for recording metrics if log_metrics is true, 109 // else no metrics are recorded. By default, metrics will always be logged. 110 void AbortClient(uint32_t client_id, const std::string& reason, 111 ClientDropReason reason_code, bool notify = true, 112 bool log_metrics = true); 113 114 // Aborts the protocol for the specified reason. Notifies all clients of 115 // the abort. Returns the new state. 116 // Calling this method on a terminal state isn't valid. 117 std::unique_ptr<SecAggServerState> Abort(const std::string& reason, 118 SecAggServerOutcome outcome); 119 120 // Returns true if the current state is Abort, false else. 121 virtual bool IsAborted() const; 122 123 // Returns true if the current state is ProtocolCompleted, false else. 124 virtual bool IsCompletedSuccessfully() const; 125 126 // Returns an error message explaining why the server aborted, if the current 127 // state is an abort state. If not returns an error Status with code 128 // FAILED_PRECONDITION. 129 virtual StatusOr<std::string> ErrorMessage() const; 130 131 // Returns an enum specifying the current state. 132 SecAggServerStateKind State() const; 133 134 // Returns the name of the current state in the form of a short string. 135 std::string StateName() const; 136 137 // Returns whether or not the server has received enough messages to be ready 138 // for the next phase of the protocol. 139 // In the PRNG Running state, it returns whether or not the PRNG has stopped 140 // running. 141 // Always false in a terminal state. 142 virtual bool ReadyForNextRound() const; 143 144 // Returns the number of valid messages received by clients this round. 145 int NumberOfMessagesReceivedInThisRound() const; 146 147 // Returns the number of clients that would still be alive if 148 // ProceedToNextRound were called immediately after. This value may be less 149 // than NumberOfMessagesReceivedInThisRound if a client fails after sending a 150 // message in this round. 151 // Note that this value is not guaranteed to be monotonically increasing, even 152 // within a round. Client failures can cause this value to decrease. 153 virtual int NumberOfClientsReadyForNextRound() const; 154 155 // Indicates the total number of clients that the server expects to receive a 156 // response from in this round (i.e. the ones that have not aborted). 157 // In the COMPLETED state, this returns the number of clients that survived to 158 // the final protocol message. 159 virtual int NumberOfAliveClients() const; 160 161 // Number of clients that failed before submitting their masked input. These 162 // clients' inputs won't be included in the aggregate value, even if the 163 // protocol succeeds. 164 int NumberOfClientsFailedBeforeSendingMaskedInput() const; 165 166 // Number of clients that failed after submitting their masked input. These 167 // clients' inputs will be included in the aggregate value, even though these 168 // clients did not complete the protocol. 169 int NumberOfClientsFailedAfterSendingMaskedInput() const; 170 171 // Number of clients that submitted a masked value, but didn't report their 172 // unmasking values fast enough to have them used in the final unmasking 173 // process. These clients' inputs will be included in the aggregate value. 174 int NumberOfClientsTerminatedWithoutUnmasking() const; 175 176 // Returns the number of live clients that have not yet submitted the expected 177 // response for the current round. In terminal states, this will be 0. 178 virtual int NumberOfPendingClients() const; 179 180 // Returns the number of inputs that will appear in the final sum, if the 181 // protocol completes. 182 // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed 183 // for the remainder of the protocol. 184 // This will be 0 if the server is aborted. This will also be 0 if the server 185 // is in an early state, prior to receiving masked inputs. It is incremented 186 // only when the server receives a masked input from a client. 187 virtual int NumberOfIncludedInputs() const; 188 189 // Whether the set of inputs that will be included in the final aggregation is 190 // fixed. 191 // If true, the value of NumberOfIncludedInputs will be fixed for the 192 // remainder of the protocol. 193 virtual bool IsNumberOfIncludedInputsCommitted() const = 0; 194 195 // Indicates the minimum number of valid messages needed to be able to 196 // successfully move to the next round. 197 // Note that this value is not guaranteed to be monotonically decreasing. 198 // Client failures can cause this value to increase. 199 // In terminal states, this returns 0. 200 virtual int MinimumMessagesNeededForNextRound() const; 201 202 // Returns the minimum threshold number of clients that need to send valid 203 // responses in order for the protocol to proceed from one round to the next. minimum_number_of_clients_to_proceed()204 inline const int minimum_number_of_clients_to_proceed() const { 205 return impl_->minimum_number_of_clients_to_proceed(); 206 } 207 208 // Returns the set of clients that aborted the protocol. Can be used by the 209 // caller to close the relevant RPC connections or just start ignoring 210 // incoming messages from those clients for performance reasons. 211 absl::flat_hash_set<uint32_t> AbortedClientIds() const; 212 213 // Returns true if the server has determined that it needs to abort itself, 214 // If the server is in a terminal state, returns false. 215 bool NeedsToAbort() const; 216 217 // Sets up a callback to be triggered when any background asynchronous work 218 // has been done. The callback is guaranteed to invoked via the server's 219 // callback scheduler. 220 // 221 // Returns true if the state supports asynchronous processing and the callback 222 // has been setup successfully. 223 // Returns false if the state doesn't support asynchronous processing or if 224 // no further asynchronous processing is possible. The callback argument is 225 // ignored in this case. 226 virtual bool SetAsyncCallback(std::function<void()> async_callback); 227 228 // Transfers ownership of the result of the protocol to the caller. Requires 229 // the server to be in a completed state; returns UNAVAILABLE otherwise. 230 // Can be called only once; any consecutive calls result in an error. 231 virtual StatusOr<std::unique_ptr<SecAggVectorMap>> Result(); 232 233 virtual ~SecAggServerState(); 234 235 protected: 236 // SecAggServerState should never be instantiated directly. 237 SecAggServerState(int number_of_clients_failed_after_sending_masked_input, 238 int number_of_clients_failed_before_sending_masked_input, 239 int number_of_clients_terminated_without_unmasking, 240 SecAggServerStateKind state_kind, 241 std::unique_ptr<SecAggServerProtocolImpl> impl); 242 impl()243 SecAggServerProtocolImpl* impl() { return impl_.get(); } 244 245 // Returns the callback interface for recording metrics. metrics()246 inline SecAggServerMetricsListener* metrics() const { 247 return impl_->metrics(); 248 } 249 250 // Returns the callback interface for sending protocol buffer messages to the 251 // client. sender()252 inline SendToClientsInterface* sender() const { return impl_->sender(); } 253 client_status(uint32_t client_id)254 inline const ClientStatus& client_status(uint32_t client_id) const { 255 return impl_->client_status(client_id); 256 } 257 set_client_status(uint32_t client_id,ClientStatus status)258 inline void set_client_status(uint32_t client_id, ClientStatus status) { 259 impl_->set_client_status(client_id, status); 260 } 261 262 // Records information about a message that was received from a client. 263 void MessageReceived(const ClientToServerWrapperMessage& message, 264 bool expected); 265 266 // Broadcasts the message and records metrics. 267 void SendBroadcast(const ServerToClientWrapperMessage& message); 268 269 // Sends the message to the given client and records metrics. 270 void Send(uint32_t recipient_id, const ServerToClientWrapperMessage& message); 271 272 // Returns an aborted version of the current state, storing the specified 273 // reason. Calling this method makes the current state unusable. The caller is 274 // responsible for sending any failure messages that need to be sent, and for 275 // doing so BEFORE calling this method. 276 // The SecAggServerOutcome outcome is used for recording metrics. 277 std::unique_ptr<SecAggServerState> AbortState(const std::string& reason, 278 SecAggServerOutcome outcome); 279 280 // ExitState must be called on the current state just before transitioning to 281 // a new state to record metrics and transfer out the shared state. 282 enum class StateTransition { 283 // Indicates a successful state transition to any state other than Aborted. 284 kSuccess = 0, 285 // Indicates transition to Aborted state. 286 kAbort = 1 287 }; 288 std::unique_ptr<SecAggServerProtocolImpl>&& ExitState( 289 StateTransition state_transition_status); 290 291 bool needs_to_abort_; 292 int number_of_clients_failed_after_sending_masked_input_; 293 int number_of_clients_failed_before_sending_masked_input_; 294 int number_of_clients_ready_for_next_round_; 295 int number_of_clients_terminated_without_unmasking_; 296 int number_of_messages_received_in_this_round_; 297 absl::Time round_start_; 298 SecAggServerStateKind state_kind_; 299 300 private: 301 // Performs state specific action when a client is aborted. HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)302 virtual void HandleAbortClient(uint32_t client_id, 303 ClientDropReason reason_code) {} 304 305 // Performs state specific action when the server is aborted. HandleAbort()306 virtual void HandleAbort() {} 307 308 std::unique_ptr<SecAggServerProtocolImpl> impl_; 309 }; 310 311 } // namespace secagg 312 } // namespace fcp 313 314 #endif // FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_ 315