xref: /aosp_15_r20/external/federated-compute/fcp/client/federated_protocol.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_FEDERATED_PROTOCOL_H_
17 #define FCP_CLIENT_FEDERATED_PROTOCOL_H_
18 
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 #include "absl/container/node_hash_map.h"
29 #include "absl/status/status.h"
30 #include "absl/status/statusor.h"
31 #include "fcp/client/engine/engine.pb.h"
32 #include "fcp/client/stats.h"
33 #include "fcp/protos/federated_api.pb.h"
34 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
35 #include "fcp/protos/plan.pb.h"
36 
37 namespace fcp {
38 namespace client {
39 
// Value types used to encode the results of a computation: either a
// TensorFlow checkpoint or a SecAgg quantized tensor.
//
// Results are keyed by string in a map:
// * Without SecAgg (simple federated aggregation, or local computation) the
//   map should hold exactly one entry - a TFCheckpoint - and downstream code
//   should ignore its key.
// * With SecAgg the map should hold
//   - at most one TFCheckpoint (whose key should again be ignored), and
//   - N QuantizedTensors, whose keys must map to the tensor names provided in
//     the server's CheckinResponse's SideChannelExecutionInfo.
using TFCheckpoint = std::string;

// A single quantized tensor. The type is move-only: copies are rejected at
// compile time, moves are allowed.
struct QuantizedTensor {
  QuantizedTensor() = default;

  // Moving is supported.
  QuantizedTensor(QuantizedTensor&&) = default;
  QuantizedTensor& operator=(QuantizedTensor&&) = default;

  // Copying is not.
  QuantizedTensor(const QuantizedTensor&) = delete;
  QuantizedTensor& operator=(const QuantizedTensor&) = delete;

  // Data members. Their relative order is kept stable on purpose.
  std::vector<uint64_t> values;
  int32_t bitwidth = 0;
  std::vector<int64_t> dimensions;
};
63 // This is equivalent to using ComputationResults =
64 //    std::map<std::string, std::variant<TFCheckpoint, QuantizedTensor>>;
65 // except copy construction and assignment are explicitly prohibited and move
66 // semantics is enforced.
67 class ComputationResults
68     : public absl::node_hash_map<std::string,
69                                  std::variant<TFCheckpoint, QuantizedTensor>> {
70  public:
71   using Base = absl::node_hash_map<std::string,
72                                    std::variant<TFCheckpoint, QuantizedTensor>>;
73   using Base::Base;
74   using Base::operator=;
75   ComputationResults(const ComputationResults&) = delete;
76   ComputationResults& operator=(const ComputationResults&) = delete;
77   ComputationResults(ComputationResults&&) = default;
78   ComputationResults& operator=(ComputationResults&&) = default;
79 };
80 
81 // An interface that represents a single Federated Compute protocol session.
82 //
83 // An instance of this class represents a single session of client-server
84 // interaction. Instances are generally stateful, and therefore cannot be
85 // reused (each session should use a dedicated instance).
86 //
87 // The protocol consists of 3 phases, which must occur in the following order:
88 // 1. A call to `EligibilityEvalCheckin()`.
89 // 2. A call to `Checkin(...)`, only if the client wasn't rejected by the server
90 //    in the previous phase.
91 // 3. A call to `ReportCompleted(...)` or `ReportNotCompleted(...)`, only if the
92 //    client wasn't rejected in the previous phase.
93 class FederatedProtocol {
94  public:
95   virtual ~FederatedProtocol() = default;
96 
97   // The unparsed plan and checkpoint payload which make up a computation. The
98   // data can be provided as either an std::string or an absl::Cord.
99   struct PlanAndCheckpointPayloads {
100     std::variant<std::string, absl::Cord> plan;
101     std::variant<std::string, absl::Cord> checkpoint;
102   };
103 
104   // An eligibility task, consisting of task payloads and an execution ID.
105   struct EligibilityEvalTask {
106     PlanAndCheckpointPayloads payloads;
107     std::string execution_id;
108     std::optional<
109         google::internal::federatedcompute::v1::PopulationEligibilitySpec>
110         population_eligibility_spec;
111   };
112   // A rejection of a client by the server.
113   struct Rejection {};
114   // Indicates that the server does not have an eligibility eval task configured
115   // for the population.
116   struct EligibilityEvalDisabled {};
117   // EligibilityEvalCheckin() returns either
118   // 1. an `EligibilityEvalTask` struct holding the payloads for an eligibility
119   //    eval task, if the population is configured with such a task. In this
120   //    case the caller should execute the task and pass the resulting
121   //    `TaskEligibilityInfo` value to the `Checkin(...)` method.
122   // 2. an `EligibilityEvalDisabled` struct if the population doesn't have an
123   //    eligibility eval task configured. In this case the caller should
124   //    continue the protocol by calling the `Checkin(...)` method without
125   //    providing a `TaskEligibilityInfo` value.
126   // 3. a `Rejection` if the server rejected this device. In this case the
127   // caller
128   //    should end its protocol interaction.
129   using EligibilityEvalCheckinResult =
130       std::variant<EligibilityEvalTask, EligibilityEvalDisabled, Rejection>;
131 
132   // Checks in with a federated server to receive the population's eligibility
133   // eval task. This method is optional and may be called 0 or 1 times. If it is
134   // called, then it must be called before any call to `Checkin(...)`.
135   //
136   // If an eligibility eval task is configured, then the
137   // `payload_uris_received_callback` function will be called with a partially
138   // populated `EligibilityEvalTask` containing all of the task's info except
139   // for the actual payloads (which are yet to be fetched at that point).
140   //
141   // Returns:
142   // - On success, an EligibilityEvalCheckinResult.
143   // - On error:
144   //   - ABORTED when one of the I/O operations got aborted by the server.
145   //   - CANCELLED when one of the I/O operations was interrupted by the client
146   //     (possibly due to a positive result from the should_abort callback).
147   //   - UNAVAILABLE when server cannot be reached or URI is invalid.
148   //   - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
149   //     specified population name is incorrect.
150   //   - UNIMPLEMENTED if an unexpected server response is received.
151   //   - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See
152   //     note in federated_protocol.cc for the reasoning for this.)
153   //   - INTERNAL for other unexpected client-side errors.
154   //   - any server-provided error code.
155   virtual absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin(
156       std::function<void(const EligibilityEvalTask&)>
157           payload_uris_received_callback) = 0;
158 
159   // Report an eligibility eval task error to the federated server.
160   // Must only be called once and after a successful call to
161   // EligibilityEvalCheckin() which returns an eligibility eval task. This
162   // method is only used to report an error happened during the computation of
163   // the eligibility eval task. If the eligibility eval computation succeeds,
164   // the success will be reported during task assignment.
165   // @param status the outcome of the eligibility eval computation.
166   virtual void ReportEligibilityEvalError(absl::Status error_status) = 0;
167 
168   // SecAgg metadata, e.g. see SecureAggregationProtocolExecutionInfo in
169   // federated_api.proto.
170   struct SecAggInfo {
171     int32_t expected_number_of_clients;
172     int32_t minimum_clients_in_server_visible_aggregate;
173   };
174 
175   // A task assignment, consisting of task payloads, a URI template to download
176   // federated select task slices with (if the plan uses federated select), a
177   // session identifier, and SecAgg-related metadata.
178   struct TaskAssignment {
179     PlanAndCheckpointPayloads payloads;
180     std::string federated_select_uri_template;
181     std::string aggregation_session_id;
182     std::optional<SecAggInfo> sec_agg_info;
183   };
184   // Checkin() returns either
185   // 1. a `TaskAssignment` struct if the client was assigned a task to run, or
186   // 2. a `Rejection` struct if the server rejected this device.
187   using CheckinResult = std::variant<TaskAssignment, Rejection>;
188 
189   // Checks in with a federated server. Must only be called once. If the
190   // `EligibilityEvalCheckin()` method was previously called, then this method
191   // must only be called if the result of that call was not a `Rejection`.
192   //
193   // If the caller previously called `EligibilityEvalCheckin()` and:
194   // - received a payload, then the `TaskEligibilityInfo` value computed by that
195   //   payload must be provided via the `task_eligibility_info` parameter.
196   // - received an `EligibilityEvalDisabled` result, then the
197   //   `task_eligibility_info` parameter should be left empty.
198   //
199   // If the caller did not previously call `EligibilityEvalCheckin()`, then the
200   // `task_eligibility_info` parameter should be left empty.
201   //
202   // If the client is assigned a task by the server, then the
203   // `payload_uris_received_callback` function will be called with a partially
204   // populated `TaskAssignment` containing all of the task's info except for the
205   // actual payloads (which are yet to be fetched at that point)
206   //
207   // Returns:
208   // - On success, a `CheckinResult`.
209   // - On error:
210   //   - ABORTED when one of the I/O operations got aborted by the server.
211   //   - CANCELLED when one of the I/O operations was interrupted by the client
212   //     (possibly due to a positive result from the should_abort callback).
213   //   - UNAVAILABLE when server cannot be reached or URI is invalid.
214   //   - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
215   //     specified population name is incorrect.
216   //   - UNIMPLEMENTED if an unexpected server response is received.
217   //   - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See
218   //     note in federated_protocol.cc for the reasoning for this.)
219   //   - INTERNAL for other unexpected client-side errors.
220   //   - any server-provided error code.
221   // TODO(team): Replace this reference to protocol-specific
222   // TaskEligibilityInfo proto with a protocol-agnostic struct.
223   virtual absl::StatusOr<CheckinResult> Checkin(
224       const std::optional<
225           google::internal::federatedml::v2::TaskEligibilityInfo>&
226           task_eligibility_info,
227       std::function<void(const TaskAssignment&)>
228           payload_uris_received_callback) = 0;
229 
230   // A list of absl::StatusOr<TaskAssignment> returned by
231   // PerformMultipleTaskAssignments. Individual absl::StatusOr<TaskAssignment>
232   // may be an error status due to failed to fetch the plan resources.
233   struct MultipleTaskAssignments {
234     std::vector<absl::StatusOr<TaskAssignment>> task_assignments;
235   };
236 
237   // Checks in with a federated server to get multiple task assignments.
238   //
239   // Must only be called once after the following conditions are met:
240   //
241   // - the caller previously called `EligibilityEvalCheckin()` and,
242   // - received a payload, and the returned EligibilityEvalTask's
243   // `PopulationEligibilitySpec` contained at least one task with
244   // TASK_ASSIGNMENT_MODE_MULTIPLE, for which the device is eligible.
245   //
246   //
247   // Returns:
248   // - On success, a `MultipleTaskAssignments`.
249   // - On error:
250   //   - ABORTED when one of the I/O operations got aborted by the server.
251   //   - CANCELLED when one of the I/O operations was interrupted by the client
252   //     (possibly due to a positive result from the should_abort callback).
253   //   - UNAVAILABLE when server cannot be reached or URI is invalid.
254   //   - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
255   //     specified population name is incorrect.
256   //   - UNIMPLEMENTED if an unexpected server response is received.
257   //   - INTERNAL for other unexpected client-side errors.
258   //   - any server-provided error code.
259   virtual absl::StatusOr<MultipleTaskAssignments>
260   PerformMultipleTaskAssignments(
261       const std::vector<std::string>& task_names) = 0;
262 
263   // Reports the result of a federated computation to the server. Must only be
264   // called once and after a successful call to Checkin().
265   // @param checkpoint A checkpoint proto.
266   // @param stats all stats reported during the computation.
267   // @param plan_duration the duration for executing the plan in the plan
268   //        engine. Does not include time spent on downloading the plan.
269   // Returns:
270   // - On success, OK.
271   // - On error (e.g. an interruption, network error, or other unexpected
272   //   error):
273   //   - ABORTED when one of the I/O operations got aborted by the server.
274   //   - CANCELLED when one of the I/O operations was interrupted by the client
275   //     (possibly due to a positive result from the should_abort callback).
276   //   - UNIMPLEMENTED if the server responded with an unexpected response
277   //     message.
278   //   - INTERNAL for other unexpected client-side errors.
279   //   - any server-provided error code.
280   virtual absl::Status ReportCompleted(
281       ComputationResults results, absl::Duration plan_duration,
282       std::optional<std::string> aggregation_session_id) = 0;
283 
284   // Reports the unsuccessful result of a federated computation to the server.
285   // Must only be called once and after a successful call to Checkin().
286   // @param phase_outcome the outcome of the federated computation.
287   // @param plan_duration the duration for executing the plan in the plan
288   //        engine. Does not include time spent on downloading the plan.
289   // Returns:
290   // - On success, OK.
291   // - On error:
292   //   - ABORTED when one of the I/O operations got aborted by the server.
293   //   - CANCELLED when one of the I/O operations was interrupted by the client
294   //     (possibly due to a positive result from the should_abort callback).
295   //   - UNIMPLEMENTED if the server responded with an unexpected response
296   //     message, or if the results to report require SecAgg support.
297   //   - INTERNAL for other unexpected client-side errors.
298   //   - any server-provided error code.
299   virtual absl::Status ReportNotCompleted(
300       engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
301       std::optional<std::string> aggregation_session_id) = 0;
302 
303   // Returns the RetryWindow the caller should use when rescheduling, based on
304   // the current protocol phase. The value returned by this method may change
305   // after every interaction with the protocol, so callers should call this
306   // right before ending their interactions with the FederatedProtocol object to
307   // ensure they use the most recent value.
308   // TODO(team): Replace this reference to protocol-specific
309   // RetryWindow proto with a protocol-agnostic struct (or just a single
310   // absl::Duration).
311   virtual google::internal::federatedml::v2::RetryWindow
312   GetLatestRetryWindow() = 0;
313 
314   // Returns the best estimate of the total bytes downloaded and uploaded over
315   // the network, plus the best estimate of the duration of wall clock time
316   // spent waiting for network requests to finish (but, for example, excluding
317   // any idle time spent waiting between issuing polling requests).
318   //
319   // Note that this estimate may still include time spent simply waiting for a
320   // server response, even if no data was being sent or received during that
321   // time. E.g. in the case of the legacy gRPC protocol where the single checkin
322   // request blocks until a task is assigned to the client.
323   //
324   // If possible, this estimate should also include time spent
325   // compressing/decompressing payloads before writing them to or after reading
326   // them from the network.
327   virtual NetworkStats GetNetworkStats() = 0;
328 
329  protected:
330   // A list of states representing the sequence of calls we expect to receive
331   // via this interface, as well as their possible outcomes. Implementations of
332   // this class are likely to share these coarse-grained states, and use them to
333   // determine which values to return from `GetLatestRetryWindow()`.
334   enum class ObjectState {
335     // The initial object state.
336     kInitialized,
337     // EligibilityEvalCheckin() was called but it failed with a 'transient'
338     // error (e.g. an UNAVAILABLE network error, although the set of transient
339     // errors is flag-defined).
340     kEligibilityEvalCheckinFailed,
341     // EligibilityEvalCheckin() was called but it failed with a 'permanent'
342     // error (e.g. a NOT_FOUND network error, although the set of permanent
343     // errors is flag-defined).
344     kEligibilityEvalCheckinFailedPermanentError,
345     // EligibilityEvalCheckin() was called, and the server rejected the client.
346     kEligibilityEvalCheckinRejected,
347     // EligibilityEvalCheckin() was called, and the server did not return an
348     // eligibility eval payload.
349     kEligibilityEvalDisabled,
350     // EligibilityEvalCheckin() was called, and the server did return an
351     // eligibility eval payload, which must then be run to produce a
352     // TaskEligibilityInfo value.
353     kEligibilityEvalEnabled,
354     // Checkin(...) was called but it failed with a 'transient' error.
355     kCheckinFailed,
356     // Checkin(...) was called but it failed with a 'permanent' error.
357     kCheckinFailedPermanentError,
358     // Checkin(...) was called, and the server rejected the client.
359     kCheckinRejected,
360     // Checkin(...) was called, and the server accepted the client and returned
361     // a payload, which must then be run to produce a report.
362     kCheckinAccepted,
363     // PerformMultipleTaskAssignments(...) was called but it failed with a
364     // 'transient' error, without receiving a single task assignment. If some
365     // task assignments were successfully received, but some others failed (e.g.
366     // because their resources failed to be downloaded), then this state won't
367     // be used.
368     kMultipleTaskAssignmentsFailed,
369     // PerformMultipleTaskAssignments(...) was called but it failed with a
370     // 'permanent' error.
371     kMultipleTaskAssignmentsFailedPermanentError,
372     // PerformMultipleTaskAssignments(...) was called but an empty list of tasks
373     // is returned by the server.
374     kMultipleTaskAssignmentsNoAvailableTask,
375     // PerformMultipleTaskAssignments(...) was called, and the server accepted
376     // the client and returned one or more payload, which must then be run to
377     // produce a report.
378     kMultipleTaskAssignmentsAccepted,
379     // Report(...) was called.
380     kReportCalled,
381     // Report(...) was called and it resulted in a 'permanent' error.
382     //
383     // Note: there is no kReportFailed (corresponding to 'transient' errors,
384     // like the other phases have), because by the time the report phase is
385     // reached, a set of RetryWindows is guaranteed to have been received from
386     // the server.
387     kReportFailedPermanentError,
388     // Report(...) was called for multiple tasks, and only a subset of the tasks
389     // succeed.
390     kReportMultipleTaskPartialError,
391   };
392 };
393 
394 }  // namespace client
395 }  // namespace fcp
396 
397 #endif  // FCP_CLIENT_FEDERATED_PROTOCOL_H_
398