/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
#define FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_

#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/aggregation/protocol/aggregation_protocol.h"
#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
#include "fcp/aggregation/protocol/checkpoint_builder.h"
#include "fcp/aggregation/protocol/checkpoint_parser.h"
#include "fcp/aggregation/protocol/configuration.pb.h"
#include "fcp/aggregation/protocol/resource_resolver.h"

namespace fcp::aggregation {

// Implementation of the simple aggregation protocol.
//
// This version of the protocol receives updates in the clear from clients in a
// TF checkpoint and aggregates them in memory. The aggregated updates are
// released only if the number of participants exceeds the configured
// threshold.
48 class SimpleAggregationProtocol final : public AggregationProtocol { 49 public: 50 // Validates the Configuration that will subsequently be used to create an 51 // instance of this protocol. 52 // Returns INVALID_ARGUMENT if the configuration is invalid. 53 static absl::Status ValidateConfig(const Configuration& configuration); 54 55 // Factory method to create an instance of the Simple Aggregation Protocol. 56 // 57 // Does not take ownership of the callback, which must refer to a valid object 58 // that outlives the SimpleAggregationProtocol instance. 59 static absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>> Create( 60 const Configuration& configuration, 61 AggregationProtocol::Callback* callback, 62 const CheckpointParserFactory* checkpoint_parser_factory, 63 const CheckpointBuilderFactory* checkpoint_builder_factory, 64 ResourceResolver* resource_resolver); 65 66 // Implementation of the overridden Aggregation Protocol methods. 67 absl::Status Start(int64_t num_clients) override; 68 absl::Status AddClients(int64_t num_clients) override; 69 absl::Status ReceiveClientMessage(int64_t client_id, 70 const ClientMessage& message) override; 71 absl::Status CloseClient(int64_t client_id, 72 absl::Status client_status) override; 73 absl::Status Complete() override; 74 absl::Status Abort() override; 75 StatusMessage GetStatus() override; 76 77 ~SimpleAggregationProtocol() override = default; 78 79 // SimpleAggregationProtocol is neither copyable nor movable. 80 SimpleAggregationProtocol(const SimpleAggregationProtocol&) = delete; 81 SimpleAggregationProtocol& operator=(const SimpleAggregationProtocol&) = 82 delete; 83 84 private: 85 // The structure representing a single aggregation intrinsic. 86 // TODO(team): Implement mapping of multiple inputs and outputs to 87 // individual TensorAggregator instances. 
88 struct Intrinsic { 89 TensorSpec input; 90 TensorSpec output; 91 std::unique_ptr<TensorAggregator> aggregator 92 ABSL_PT_GUARDED_BY(&SimpleAggregationProtocol::aggregation_mu_); 93 }; 94 95 // Private constructor. 96 SimpleAggregationProtocol( 97 std::vector<Intrinsic> intrinsics, 98 AggregationProtocol::Callback* callback, 99 const CheckpointParserFactory* checkpoint_parser_factory, 100 const CheckpointBuilderFactory* checkpoint_builder_factory, 101 ResourceResolver* resource_resolver); 102 103 // Creates an aggregation intrinsic based on the intrinsic configuration. 104 static absl::StatusOr<Intrinsic> CreateIntrinsic( 105 const Configuration::ServerAggregationConfig& aggregation_config); 106 107 // Describes the overall protocol state. 108 enum ProtocolState { 109 // The initial state indicating that the protocol was created. 110 PROTOCOL_CREATED, 111 // The protocol `Start` method has been called. 112 PROTOCOL_STARTED, 113 // The protocol `Complete` method has finished successfully. 114 PROTOCOL_COMPLETED, 115 // The protocol `Abort` method has been called. 116 PROTOCOL_ABORTED 117 }; 118 119 // Describes state of each client participating in the protocol. 120 enum ClientState : uint8_t { 121 // No input received from the client yet. 122 CLIENT_PENDING, 123 // Client input received but the aggregation still pending, which may 124 // be the case when there are multiple concurrent ReceiveClientMessage 125 // calls. 126 CLIENT_RECEIVED_INPUT_AND_PENDING, 127 // Client input has been successfully aggregated. 128 CLIENT_COMPLETED, 129 // Client failed either by being closed with an error or by submitting a 130 // malformed input. 131 CLIENT_FAILED, 132 // Client which has been aborted by the server before its input has been 133 // received. 134 CLIENT_ABORTED, 135 // Client input has been received but discarded, for example due to the 136 // protocol Abort method being called. 
137 CLIENT_DISCARDED 138 }; 139 140 // Returns string representation of the protocol state. 141 static absl::string_view ProtocolStateDebugString(ProtocolState state); 142 143 // Returns string representation of the client state. 144 static absl::string_view ClientStateDebugString(ClientState state); 145 146 // Returns an error if the current protocol state isn't the expected one. 147 absl::Status CheckProtocolState(ProtocolState state) const 148 ABSL_SHARED_LOCKS_REQUIRED(state_mu_); 149 150 // Changes the protocol state. 151 void SetProtocolState(ProtocolState state) 152 ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); 153 154 // Gets the client state for the given client ID. 155 absl::StatusOr<ClientState> GetClientState(int64_t client_id) const 156 ABSL_SHARED_LOCKS_REQUIRED(state_mu_); 157 158 // Sets the client state for the given client ID. 159 void SetClientState(int64_t client_id, ClientState state) 160 ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); 161 162 // Parses and validates the client report. 163 // This function involves a potentially expensive I/O and parsing and should 164 // run concurrently as much as possible. The ABSL_LOCKS_EXCLUDED attribution 165 // below is used to emphasize that. 166 using TensorMap = absl::flat_hash_map<std::string, Tensor>; 167 absl::StatusOr<TensorMap> ParseCheckpoint(absl::Cord report) const 168 ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_); 169 170 // Aggregates the input via the underlying aggregators. 171 absl::Status AggregateClientInput(TensorMap tensor_map) 172 ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_); 173 174 // Produces the report via the underlying aggregators. 175 absl::StatusOr<absl::Cord> CreateReport() 176 ABSL_LOCKS_EXCLUDED(aggregation_mu_); 177 178 // Protects the mutable state. 179 absl::Mutex state_mu_; 180 // Protects calls into the aggregators. 181 absl::Mutex aggregation_mu_; 182 // This indicates that the aggregation has finished either by completing 183 // the protocol or by aborting it. 
This can be triggered without locking on 184 // the aggregation_mu_ mutex first to allow aborting the protocol promptly and 185 // discarding all the pending aggregation calls. 186 std::atomic_bool aggregation_finished_ = false; 187 188 // The overall state of the protocol. 189 ProtocolState protocol_state_ ABSL_GUARDED_BY(state_mu_); 190 191 // Holds state of all clients. The length of the vector equals 192 // to the number of clients accepted into the protocol. 193 std::vector<ClientState> client_states_ ABSL_GUARDED_BY(state_mu_); 194 195 // Counters for various client states other than pending. 196 // Note that the number of pending clients can be found by subtracting the 197 // sum of the below counters from `client_states_.size()`. 198 uint64_t num_clients_received_and_pending_ ABSL_GUARDED_BY(state_mu_) = 0; 199 uint64_t num_clients_aggregated_ ABSL_GUARDED_BY(state_mu_) = 0; 200 uint64_t num_clients_failed_ ABSL_GUARDED_BY(state_mu_) = 0; 201 uint64_t num_clients_aborted_ ABSL_GUARDED_BY(state_mu_) = 0; 202 uint64_t num_clients_discarded_ ABSL_GUARDED_BY(state_mu_) = 0; 203 204 // Intrinsics are immutable and shouldn't be guarded by the either of mutexes. 205 // Please note that the access to the aggregators that intrinsics point to 206 // still needs to be strictly sequential. That is guarded separatedly by 207 // `aggregators_mu_`. 208 std::vector<Intrinsic> const intrinsics_; 209 210 AggregationProtocol::Callback* const callback_; 211 const CheckpointParserFactory* const checkpoint_parser_factory_; 212 const CheckpointBuilderFactory* const checkpoint_builder_factory_; 213 ResourceResolver* const resource_resolver_; 214 }; 215 } // namespace fcp::aggregation 216 217 #endif // FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_ 218