xref: /aosp_15_r20/external/federated-compute/fcp/client/http/http_federated_protocol.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
17 #define FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "fcp/base/clock.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/wall_clock_stopwatch.h"
#include "fcp/client/cache/resource_cache.h"
#include "fcp/client/engine/engine.pb.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/client/fl_runner.pb.h"
#include "fcp/client/flags.h"
#include "fcp/client/http/http_client.h"
#include "fcp/client/http/in_memory_request_response.h"
#include "fcp/client/http/protocol_request_helper.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/secagg_runner.h"
#include "fcp/client/selector_context.pb.h"
#include "fcp/client/stats.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/federatedcompute/common.pb.h"
#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
#include "fcp/protos/federatedcompute/task_assignments.pb.h"
#include "fcp/protos/plan.pb.h"
#include "fcp/secagg/client/secagg_client.h"
56 
57 namespace fcp {
58 namespace client {
59 namespace http {
60 
61 // Implements a single session of the HTTP-based Federated Compute protocol.
62 class HttpFederatedProtocol : public fcp::client::FederatedProtocol {
63  public:
64   HttpFederatedProtocol(
65       Clock* clock, LogManager* log_manager, const Flags* flags,
66       HttpClient* http_client,
67       std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
68       SecAggEventPublisher* secagg_event_publisher,
69       absl::string_view entry_point_uri, absl::string_view api_key,
70       absl::string_view population_name, absl::string_view retry_token,
71       absl::string_view client_version,
72       absl::string_view attestation_measurement,
73       std::function<bool()> should_abort, absl::BitGen bit_gen,
74       const InterruptibleRunner::TimingConfig& timing_config,
75       cache::ResourceCache* resource_cache);
76 
77   ~HttpFederatedProtocol() override = default;
78 
79   absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
80   EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)>
81                              payload_uris_received_callback) override;
82 
83   void ReportEligibilityEvalError(absl::Status error_status) override;
84 
85   absl::StatusOr<fcp::client::FederatedProtocol::CheckinResult> Checkin(
86       const std::optional<
87           google::internal::federatedml::v2::TaskEligibilityInfo>&
88           task_eligibility_info,
89       std::function<void(const TaskAssignment&)> payload_uris_received_callback)
90       override;
91 
92   absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments(
93       const std::vector<std::string>& task_names) override;
94 
95   absl::Status ReportCompleted(
96       ComputationResults results, absl::Duration plan_duration,
97       std::optional<std::string> aggregation_session_id) override;
98 
99   absl::Status ReportNotCompleted(
100       engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
101       std::optional<std::string> aggregation_session_id) override;
102 
103   google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
104       override;
105 
106   NetworkStats GetNetworkStats() override;
107 
108  private:
109   // Information for a given task.
110   struct PerTaskInfo {
111     std::unique_ptr<ProtocolRequestCreator> aggregation_request_creator;
112     std::unique_ptr<ProtocolRequestCreator> data_upload_request_creator;
113     std::string session_id;
114     // The identifier of the aggregation session we are participating in.
115     std::string aggregation_session_id;
116     // The token authorizing the client to participate in an aggregation
117     // session.
118     std::string aggregation_authorization_token;
119     // The name identifying the task that was assigned.
120     std::string task_name;
121     // Unique identifier for the client's participation in an aggregation
122     // session.
123     std::string aggregation_client_token;
124     // Resource name for the checkpoint in simple aggregation.
125     std::string aggregation_resource_name;
126     // Each task's state is tracked individually starting from the end of
127     // check-in or multiple task assignments. The states from all of the tasks
128     // will be used collectively to determine which retry window to use.
129     ObjectState state = ObjectState::kInitialized;
130   };
131 
132   // Helper function to perform an eligibility eval task request and get its
133   // response.
134   absl::StatusOr<InMemoryHttpResponse> PerformEligibilityEvalTaskRequest();
135 
136   // Helper function for handling an eligibility eval task response (incl.
137   // fetching any resources, if necessary).
138   absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
139   HandleEligibilityEvalTaskResponse(
140       absl::StatusOr<InMemoryHttpResponse> http_response,
141       std::function<void(const EligibilityEvalTask&)>
142           payload_uris_received_callback);
143 
144   absl::StatusOr<std::unique_ptr<HttpRequest>>
145   CreateReportEligibilityEvalTaskResultRequest(absl::Status status);
146 
147   // Helper function to perform an ReportEligibilityEvalResult request.
148   absl::Status ReportEligibilityEvalErrorInternal(absl::Status error_status);
149 
150   // Helper function to perform a task assignment request and get its response.
151   absl::StatusOr<InMemoryHttpResponse>
152   PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
153       const std::optional<
154           ::google::internal::federatedml::v2::TaskEligibilityInfo>&
155           task_eligibility_info);
156 
157   // Helper function for handling the 'outer' task assignment response, which
158   // consists of an `Operation` which may or may not need to be polled before a
159   // final 'inner' response is available.
160   absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult>
161   HandleTaskAssignmentOperationResponse(
162       absl::StatusOr<InMemoryHttpResponse> http_response,
163       std::function<void(const TaskAssignment&)>
164           payload_uris_received_callback);
165 
166   // Helper function for handling an 'inner' task assignment response (i.e.
167   // after the outer `Operation` has concluded). This includes fetching any
168   // resources, if necessary.
169   absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult>
170   HandleTaskAssignmentInnerResponse(
171       const google::internal::federatedcompute::v1::StartTaskAssignmentResponse&
172           response_proto,
173       std::function<void(const TaskAssignment&)>
174           payload_uris_received_callback);
175 
176   // Helper function for reporting result via simple aggregation.
177   absl::Status ReportViaSimpleAggregation(ComputationResults results,
178                                           absl::Duration plan_duration,
179                                           PerTaskInfo& task_info);
180   // Helper function to perform a StartDataUploadRequest and a ReportTaskResult
181   // request concurrently.
182   // This method will only return the response from the StartDataUploadRequest.
183   absl::StatusOr<InMemoryHttpResponse>
184   PerformStartDataUploadRequestAndReportTaskResult(absl::Duration plan_duration,
185                                                    PerTaskInfo& task_info);
186 
187   // Helper function for handling a longrunning operation returned by a
188   // StartDataAggregationUpload request.
189   absl::Status HandleStartDataAggregationUploadOperationResponse(
190       absl::StatusOr<InMemoryHttpResponse> http_response,
191       PerTaskInfo& task_info);
192 
193   // Helper function to perform data upload via simple aggregation.
194   absl::Status UploadDataViaSimpleAgg(std::string tf_checkpoint,
195                                       PerTaskInfo& task_info);
196 
197   // Helper function to perform a SubmitAggregationResult request.
198   absl::Status SubmitAggregationResult(PerTaskInfo& task_info);
199 
200   // Helper function to perform an AbortAggregation request.
201   // We only provide the server with a simplified error message.
202   absl::Status AbortAggregation(absl::Status original_error_status,
203                                 absl::string_view error_message_for_server,
204                                 PerTaskInfo& task_info);
205 
206   // Helper function for reporting via secure aggregation.
207   absl::Status ReportViaSecureAggregation(ComputationResults results,
208                                           absl::Duration plan_duration,
209                                           PerTaskInfo& task_info);
210 
211   // Helper function to perform a StartSecureAggregationRequest and a
212   // ReportTaskResultRequest.
213   absl::StatusOr<
214       google::internal::federatedcompute::v1::StartSecureAggregationResponse>
215   StartSecureAggregationAndReportTaskResult(absl::Duration plan_duration,
216                                             PerTaskInfo& task_info);
217 
218   struct TaskResources {
219     const ::google::internal::federatedcompute::v1::Resource& plan;
220     const ::google::internal::federatedcompute::v1::Resource& checkpoint;
221   };
222 
223   // Helper function for fetching the checkpoint/plan resources for an
224   // eligibility eval task or regular task.
225   absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources(
226       TaskResources task_resources);
227 
228   // Helper function for fetching the PopulationEligibilitySpec.
229   absl::StatusOr<
230       google::internal::federatedcompute::v1::PopulationEligibilitySpec>
231   FetchPopulationEligibilitySpec(
232       const ::google::internal::federatedcompute::v1::Resource&
233           population_eligibility_spec_resource);
234 
235   // Helper that moves to the given object state if the given status represents
236   // a permanent error.
237   void UpdateObjectStateIfPermanentError(
238       absl::Status status, ObjectState permanent_error_object_state);
239 
240   ObjectState GetTheLatestStateFromAllTasks();
241 
242   // This ObjectState tracks states until the end of check-in or multiple task
243   // assignments.  Once a task is assigned, the state is tracked inside the
244   // task_info_map_ for multiple task assignments or default_task_info_ for
245   // single task check-in.
246   ObjectState object_state_;
247   Clock& clock_;
248   LogManager* log_manager_;
249   const Flags* const flags_;
250   HttpClient* const http_client_;
251   std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_;
252   SecAggEventPublisher* secagg_event_publisher_;
253   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
254   std::unique_ptr<ProtocolRequestCreator> eligibility_eval_request_creator_;
255   std::unique_ptr<ProtocolRequestCreator> task_assignment_request_creator_;
256   std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
257       WallClockStopwatch::Create();
258   ProtocolRequestHelper protocol_request_helper_;
259   const std::string api_key_;
260   const std::string population_name_;
261   const std::string retry_token_;
262   const std::string client_version_;
263   const std::string attestation_measurement_;
264   std::function<bool()> should_abort_;
265   absl::BitGen bit_gen_;
266   const InterruptibleRunner::TimingConfig timing_config_;
267   // The graceful waiting period for cancellation requests before checking
268   // whether the client should be interrupted.
269   const absl::Duration waiting_period_for_cancellation_;
270   // The set of canonical error codes that should be treated as 'permanent'
271   // errors.
272   absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_;
273   int64_t bytes_downloaded_ = 0;
274   int64_t bytes_uploaded_ = 0;
275   // Represents 2 absolute retry timestamps to use when the device is rejected
276   // or accepted. The retry timestamps will have been generated based on the
277   // retry windows specified in the server's EligibilityEvalTaskResponse message
278   // and the time at which that message was received.
279   struct RetryTimes {
280     absl::Time retry_time_if_rejected;
281     absl::Time retry_time_if_accepted;
282   };
283   // Represents the information received via the EligibilityEvalTaskResponse
284   // message. This field will have an absent value until that message has been
285   // received.
286   std::optional<RetryTimes> retry_times_;
287   std::string pre_task_assignment_session_id_;
288 
289   // A map of aggregation_session_id to per-task information.
290   // Only tasks from the multiple task assignments will be tracked in this map.
291   absl::flat_hash_map<std::string, PerTaskInfo> task_info_map_;
292   // The task received from the regular check-in will be tracked here.
293   PerTaskInfo default_task_info_;
294 
295   // Set this field to true if an eligibility eval task was received from the
296   // server in the EligibilityEvalTaskResponse.
297   bool eligibility_eval_enabled_ = false;
298   // `nullptr` if the feature is disabled.
299   cache::ResourceCache* resource_cache_;
300 };
301 
302 }  // namespace http
303 }  // namespace client
304 }  // namespace fcp
305 
306 #endif  // FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
307