1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
18 #define FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
19 
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/aggregation/protocol/aggregation_protocol.h"
#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
#include "fcp/aggregation/protocol/checkpoint_builder.h"
#include "fcp/aggregation/protocol/checkpoint_parser.h"
#include "fcp/aggregation/protocol/configuration.pb.h"
#include "fcp/aggregation/protocol/resource_resolver.h"
40 
41 namespace fcp::aggregation {
42 
43 // Implementation of the simple aggregation protocol.
44 //
45 // This version of the protocol receives updates in the clear from clients in a
46 // TF checkpoint and aggregates them in memory. The aggregated updates are
47 // released only if the number of participants exceed configured threshold.
48 class SimpleAggregationProtocol final : public AggregationProtocol {
49  public:
50   // Validates the Configuration that will subsequently be used to create an
51   // instance of this protocol.
52   // Returns INVALID_ARGUMENT if the configuration is invalid.
53   static absl::Status ValidateConfig(const Configuration& configuration);
54 
55   // Factory method to create an instance of the Simple Aggregation Protocol.
56   //
57   // Does not take ownership of the callback, which must refer to a valid object
58   // that outlives the SimpleAggregationProtocol instance.
59   static absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>> Create(
60       const Configuration& configuration,
61       AggregationProtocol::Callback* callback,
62       const CheckpointParserFactory* checkpoint_parser_factory,
63       const CheckpointBuilderFactory* checkpoint_builder_factory,
64       ResourceResolver* resource_resolver);
65 
66   // Implementation of the overridden Aggregation Protocol methods.
67   absl::Status Start(int64_t num_clients) override;
68   absl::Status AddClients(int64_t num_clients) override;
69   absl::Status ReceiveClientMessage(int64_t client_id,
70                                     const ClientMessage& message) override;
71   absl::Status CloseClient(int64_t client_id,
72                            absl::Status client_status) override;
73   absl::Status Complete() override;
74   absl::Status Abort() override;
75   StatusMessage GetStatus() override;
76 
77   ~SimpleAggregationProtocol() override = default;
78 
79   // SimpleAggregationProtocol is neither copyable nor movable.
80   SimpleAggregationProtocol(const SimpleAggregationProtocol&) = delete;
81   SimpleAggregationProtocol& operator=(const SimpleAggregationProtocol&) =
82       delete;
83 
84  private:
85   // The structure representing a single aggregation intrinsic.
86   // TODO(team): Implement mapping of multiple inputs and outputs to
87   // individual TensorAggregator instances.
88   struct Intrinsic {
89     TensorSpec input;
90     TensorSpec output;
91     std::unique_ptr<TensorAggregator> aggregator
92         ABSL_PT_GUARDED_BY(&SimpleAggregationProtocol::aggregation_mu_);
93   };
94 
95   // Private constructor.
96   SimpleAggregationProtocol(
97       std::vector<Intrinsic> intrinsics,
98       AggregationProtocol::Callback* callback,
99       const CheckpointParserFactory* checkpoint_parser_factory,
100       const CheckpointBuilderFactory* checkpoint_builder_factory,
101       ResourceResolver* resource_resolver);
102 
103   // Creates an aggregation intrinsic based on the intrinsic configuration.
104   static absl::StatusOr<Intrinsic> CreateIntrinsic(
105       const Configuration::ServerAggregationConfig& aggregation_config);
106 
107   // Describes the overall protocol state.
108   enum ProtocolState {
109     // The initial state indicating that the protocol was created.
110     PROTOCOL_CREATED,
111     // The protocol `Start` method has been called.
112     PROTOCOL_STARTED,
113     // The protocol `Complete` method has finished successfully.
114     PROTOCOL_COMPLETED,
115     // The protocol `Abort` method has been called.
116     PROTOCOL_ABORTED
117   };
118 
119   // Describes state of each client participating in the protocol.
120   enum ClientState : uint8_t {
121     // No input received from the client yet.
122     CLIENT_PENDING,
123     // Client input received but the aggregation still pending, which may
124     // be the case when there are multiple concurrent ReceiveClientMessage
125     // calls.
126     CLIENT_RECEIVED_INPUT_AND_PENDING,
127     // Client input has been successfully aggregated.
128     CLIENT_COMPLETED,
129     // Client failed either by being closed with an error or by submitting a
130     // malformed input.
131     CLIENT_FAILED,
132     // Client which has been aborted by the server before its input has been
133     // received.
134     CLIENT_ABORTED,
135     // Client input has been received but discarded, for example due to the
136     // protocol Abort method being called.
137     CLIENT_DISCARDED
138   };
139 
140   // Returns string representation of the protocol state.
141   static absl::string_view ProtocolStateDebugString(ProtocolState state);
142 
143   // Returns string representation of the client state.
144   static absl::string_view ClientStateDebugString(ClientState state);
145 
146   // Returns an error if the current protocol state isn't the expected one.
147   absl::Status CheckProtocolState(ProtocolState state) const
148       ABSL_SHARED_LOCKS_REQUIRED(state_mu_);
149 
150   // Changes the protocol state.
151   void SetProtocolState(ProtocolState state)
152       ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
153 
154   // Gets the client state for the given client ID.
155   absl::StatusOr<ClientState> GetClientState(int64_t client_id) const
156       ABSL_SHARED_LOCKS_REQUIRED(state_mu_);
157 
158   // Sets the client state for the given client ID.
159   void SetClientState(int64_t client_id, ClientState state)
160       ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
161 
162   // Parses and validates the client report.
163   // This function involves a potentially expensive I/O and parsing and should
164   // run concurrently as much as possible. The ABSL_LOCKS_EXCLUDED attribution
165   // below is used to emphasize that.
166   using TensorMap = absl::flat_hash_map<std::string, Tensor>;
167   absl::StatusOr<TensorMap> ParseCheckpoint(absl::Cord report) const
168       ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_);
169 
170   // Aggregates the input via the underlying aggregators.
171   absl::Status AggregateClientInput(TensorMap tensor_map)
172       ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_);
173 
174   // Produces the report via the underlying aggregators.
175   absl::StatusOr<absl::Cord> CreateReport()
176       ABSL_LOCKS_EXCLUDED(aggregation_mu_);
177 
178   // Protects the mutable state.
179   absl::Mutex state_mu_;
180   // Protects calls into the aggregators.
181   absl::Mutex aggregation_mu_;
182   // This indicates that the aggregation has finished either by completing
183   // the protocol or by aborting it. This can be triggered without locking on
184   // the aggregation_mu_ mutex first to allow aborting the protocol promptly and
185   // discarding all the pending aggregation calls.
186   std::atomic_bool aggregation_finished_ = false;
187 
188   // The overall state of the protocol.
189   ProtocolState protocol_state_ ABSL_GUARDED_BY(state_mu_);
190 
191   // Holds state of all clients. The length of the vector equals
192   // to the number of clients accepted into the protocol.
193   std::vector<ClientState> client_states_ ABSL_GUARDED_BY(state_mu_);
194 
195   // Counters for various client states other than pending.
196   // Note that the number of pending clients can be found by subtracting the
197   // sum of the below counters from `client_states_.size()`.
198   uint64_t num_clients_received_and_pending_ ABSL_GUARDED_BY(state_mu_) = 0;
199   uint64_t num_clients_aggregated_ ABSL_GUARDED_BY(state_mu_) = 0;
200   uint64_t num_clients_failed_ ABSL_GUARDED_BY(state_mu_) = 0;
201   uint64_t num_clients_aborted_ ABSL_GUARDED_BY(state_mu_) = 0;
202   uint64_t num_clients_discarded_ ABSL_GUARDED_BY(state_mu_) = 0;
203 
204   // Intrinsics are immutable and shouldn't be guarded by the either of mutexes.
205   // Please note that the access to the aggregators that intrinsics point to
206   // still needs to be strictly sequential. That is guarded separatedly by
207   // `aggregators_mu_`.
208   std::vector<Intrinsic> const intrinsics_;
209 
210   AggregationProtocol::Callback* const callback_;
211   const CheckpointParserFactory* const checkpoint_parser_factory_;
212   const CheckpointBuilderFactory* const checkpoint_builder_factory_;
213   ResourceResolver* const resource_resolver_;
214 };
215 }  // namespace fcp::aggregation
216 
217 #endif  // FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
218