1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_ 17 #define FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_ 18 19 #include <cstdint> 20 #include <functional> 21 #include <memory> 22 #include <optional> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "absl/container/flat_hash_set.h" 28 #include "absl/random/random.h" 29 #include "absl/status/status.h" 30 #include "absl/status/statusor.h" 31 #include "absl/strings/string_view.h" 32 #include "absl/time/time.h" 33 #include "fcp/base/clock.h" 34 #include "fcp/base/monitoring.h" 35 #include "fcp/base/wall_clock_stopwatch.h" 36 #include "fcp/client/cache/resource_cache.h" 37 #include "fcp/client/engine/engine.pb.h" 38 #include "fcp/client/federated_protocol.h" 39 #include "fcp/client/fl_runner.pb.h" 40 #include "fcp/client/flags.h" 41 #include "fcp/client/http/http_client.h" 42 #include "fcp/client/http/in_memory_request_response.h" 43 #include "fcp/client/http/protocol_request_helper.h" 44 #include "fcp/client/interruptible_runner.h" 45 #include "fcp/client/log_manager.h" 46 #include "fcp/client/secagg_runner.h" 47 #include "fcp/client/selector_context.pb.h" 48 #include "fcp/client/stats.h" 49 #include "fcp/protos/federated_api.pb.h" 50 #include "fcp/protos/federatedcompute/common.pb.h" 51 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h" 52 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h" 53 #include "fcp/protos/federatedcompute/task_assignments.pb.h" 54 #include "fcp/protos/plan.pb.h" 55 #include "fcp/secagg/client/secagg_client.h" 56 57 namespace fcp { 58 namespace client { 59 namespace http { 60 61 // Implements a single session of the HTTP-based Federated Compute protocol. 62 class HttpFederatedProtocol : public fcp::client::FederatedProtocol { 63 public: 64 HttpFederatedProtocol( 65 Clock* clock, LogManager* log_manager, const Flags* flags, 66 HttpClient* http_client, 67 std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory, 68 SecAggEventPublisher* secagg_event_publisher, 69 absl::string_view entry_point_uri, absl::string_view api_key, 70 absl::string_view population_name, absl::string_view retry_token, 71 absl::string_view client_version, 72 absl::string_view attestation_measurement, 73 std::function<bool()> should_abort, absl::BitGen bit_gen, 74 const InterruptibleRunner::TimingConfig& timing_config, 75 cache::ResourceCache* resource_cache); 76 77 ~HttpFederatedProtocol() override = default; 78 79 absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult> 80 EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)> 81 payload_uris_received_callback) override; 82 83 void ReportEligibilityEvalError(absl::Status error_status) override; 84 85 absl::StatusOr<fcp::client::FederatedProtocol::CheckinResult> Checkin( 86 const std::optional< 87 google::internal::federatedml::v2::TaskEligibilityInfo>& 88 task_eligibility_info, 89 std::function<void(const TaskAssignment&)> payload_uris_received_callback) 90 override; 91 92 absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments( 93 const std::vector<std::string>& task_names) override; 94 95 absl::Status ReportCompleted( 96 ComputationResults results, absl::Duration plan_duration, 97 std::optional<std::string> aggregation_session_id) override; 98 99 absl::Status ReportNotCompleted( 100 engine::PhaseOutcome phase_outcome, absl::Duration plan_duration, 101 std::optional<std::string> aggregation_session_id) override; 102 103 google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow() 104 override; 105 106 NetworkStats GetNetworkStats() override; 107 108 private: 109 // Information for a given task. 110 struct PerTaskInfo { 111 std::unique_ptr<ProtocolRequestCreator> aggregation_request_creator; 112 std::unique_ptr<ProtocolRequestCreator> data_upload_request_creator; 113 std::string session_id; 114 // The identifier of the aggregation session we are participating in. 115 std::string aggregation_session_id; 116 // The token authorizing the client to participate in an aggregation 117 // session. 118 std::string aggregation_authorization_token; 119 // The name identifying the task that was assigned. 120 std::string task_name; 121 // Unique identifier for the client's participation in an aggregation 122 // session. 123 std::string aggregation_client_token; 124 // Resource name for the checkpoint in simple aggregation. 125 std::string aggregation_resource_name; 126 // Each task's state is tracked individually starting from the end of 127 // check-in or multiple task assignments. The states from all of the tasks 128 // will be used collectively to determine which retry window to use. 129 ObjectState state = ObjectState::kInitialized; 130 }; 131 132 // Helper function to perform an eligibility eval task request and get its 133 // response. 134 absl::StatusOr<InMemoryHttpResponse> PerformEligibilityEvalTaskRequest(); 135 136 // Helper function for handling an eligibility eval task response (incl. 137 // fetching any resources, if necessary). 138 absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult> 139 HandleEligibilityEvalTaskResponse( 140 absl::StatusOr<InMemoryHttpResponse> http_response, 141 std::function<void(const EligibilityEvalTask&)> 142 payload_uris_received_callback); 143 144 absl::StatusOr<std::unique_ptr<HttpRequest>> 145 CreateReportEligibilityEvalTaskResultRequest(absl::Status status); 146 147 // Helper function to perform an ReportEligibilityEvalResult request. 148 absl::Status ReportEligibilityEvalErrorInternal(absl::Status error_status); 149 150 // Helper function to perform a task assignment request and get its response. 151 absl::StatusOr<InMemoryHttpResponse> 152 PerformTaskAssignmentAndReportEligibilityEvalResultRequests( 153 const std::optional< 154 ::google::internal::federatedml::v2::TaskEligibilityInfo>& 155 task_eligibility_info); 156 157 // Helper function for handling the 'outer' task assignment response, which 158 // consists of an `Operation` which may or may not need to be polled before a 159 // final 'inner' response is available. 160 absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> 161 HandleTaskAssignmentOperationResponse( 162 absl::StatusOr<InMemoryHttpResponse> http_response, 163 std::function<void(const TaskAssignment&)> 164 payload_uris_received_callback); 165 166 // Helper function for handling an 'inner' task assignment response (i.e. 167 // after the outer `Operation` has concluded). This includes fetching any 168 // resources, if necessary. 169 absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> 170 HandleTaskAssignmentInnerResponse( 171 const google::internal::federatedcompute::v1::StartTaskAssignmentResponse& 172 response_proto, 173 std::function<void(const TaskAssignment&)> 174 payload_uris_received_callback); 175 176 // Helper function for reporting result via simple aggregation. 177 absl::Status ReportViaSimpleAggregation(ComputationResults results, 178 absl::Duration plan_duration, 179 PerTaskInfo& task_info); 180 // Helper function to perform a StartDataUploadRequest and a ReportTaskResult 181 // request concurrently. 182 // This method will only return the response from the StartDataUploadRequest. 183 absl::StatusOr<InMemoryHttpResponse> 184 PerformStartDataUploadRequestAndReportTaskResult(absl::Duration plan_duration, 185 PerTaskInfo& task_info); 186 187 // Helper function for handling a longrunning operation returned by a 188 // StartDataAggregationUpload request. 189 absl::Status HandleStartDataAggregationUploadOperationResponse( 190 absl::StatusOr<InMemoryHttpResponse> http_response, 191 PerTaskInfo& task_info); 192 193 // Helper function to perform data upload via simple aggregation. 194 absl::Status UploadDataViaSimpleAgg(std::string tf_checkpoint, 195 PerTaskInfo& task_info); 196 197 // Helper function to perform a SubmitAggregationResult request. 198 absl::Status SubmitAggregationResult(PerTaskInfo& task_info); 199 200 // Helper function to perform an AbortAggregation request. 201 // We only provide the server with a simplified error message. 202 absl::Status AbortAggregation(absl::Status original_error_status, 203 absl::string_view error_message_for_server, 204 PerTaskInfo& task_info); 205 206 // Helper function for reporting via secure aggregation. 207 absl::Status ReportViaSecureAggregation(ComputationResults results, 208 absl::Duration plan_duration, 209 PerTaskInfo& task_info); 210 211 // Helper function to perform a StartSecureAggregationRequest and a 212 // ReportTaskResultRequest. 213 absl::StatusOr< 214 google::internal::federatedcompute::v1::StartSecureAggregationResponse> 215 StartSecureAggregationAndReportTaskResult(absl::Duration plan_duration, 216 PerTaskInfo& task_info); 217 218 struct TaskResources { 219 const ::google::internal::federatedcompute::v1::Resource& plan; 220 const ::google::internal::federatedcompute::v1::Resource& checkpoint; 221 }; 222 223 // Helper function for fetching the checkpoint/plan resources for an 224 // eligibility eval task or regular task. 225 absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources( 226 TaskResources task_resources); 227 228 // Helper function for fetching the PopulationEligibilitySpec. 229 absl::StatusOr< 230 google::internal::federatedcompute::v1::PopulationEligibilitySpec> 231 FetchPopulationEligibilitySpec( 232 const ::google::internal::federatedcompute::v1::Resource& 233 population_eligibility_spec_resource); 234 235 // Helper that moves to the given object state if the given status represents 236 // a permanent error. 237 void UpdateObjectStateIfPermanentError( 238 absl::Status status, ObjectState permanent_error_object_state); 239 240 ObjectState GetTheLatestStateFromAllTasks(); 241 242 // This ObjectState tracks states until the end of check-in or multiple task 243 // assignments. Once a task is assigned, the state is tracked inside the 244 // task_info_map_ for multiple task assignments or default_task_info_ for 245 // single task check-in. 246 ObjectState object_state_; 247 Clock& clock_; 248 LogManager* log_manager_; 249 const Flags* const flags_; 250 HttpClient* const http_client_; 251 std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_; 252 SecAggEventPublisher* secagg_event_publisher_; 253 std::unique_ptr<InterruptibleRunner> interruptible_runner_; 254 std::unique_ptr<ProtocolRequestCreator> eligibility_eval_request_creator_; 255 std::unique_ptr<ProtocolRequestCreator> task_assignment_request_creator_; 256 std::unique_ptr<WallClockStopwatch> network_stopwatch_ = 257 WallClockStopwatch::Create(); 258 ProtocolRequestHelper protocol_request_helper_; 259 const std::string api_key_; 260 const std::string population_name_; 261 const std::string retry_token_; 262 const std::string client_version_; 263 const std::string attestation_measurement_; 264 std::function<bool()> should_abort_; 265 absl::BitGen bit_gen_; 266 const InterruptibleRunner::TimingConfig timing_config_; 267 // The graceful waiting period for cancellation requests before checking 268 // whether the client should be interrupted. 269 const absl::Duration waiting_period_for_cancellation_; 270 // The set of canonical error codes that should be treated as 'permanent' 271 // errors. 272 absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_; 273 int64_t bytes_downloaded_ = 0; 274 int64_t bytes_uploaded_ = 0; 275 // Represents 2 absolute retry timestamps to use when the device is rejected 276 // or accepted. The retry timestamps will have been generated based on the 277 // retry windows specified in the server's EligibilityEvalTaskResponse message 278 // and the time at which that message was received. 279 struct RetryTimes { 280 absl::Time retry_time_if_rejected; 281 absl::Time retry_time_if_accepted; 282 }; 283 // Represents the information received via the EligibilityEvalTaskResponse 284 // message. This field will have an absent value until that message has been 285 // received. 286 std::optional<RetryTimes> retry_times_; 287 std::string pre_task_assignment_session_id_; 288 289 // A map of aggregation_session_id to per-task information. 290 // Only tasks from the multiple task assignments will be tracked in this map. 291 absl::flat_hash_map<std::string, PerTaskInfo> task_info_map_; 292 // The task received from the regular check-in will be tracked here. 293 PerTaskInfo default_task_info_; 294 295 // Set this field to true if an eligibility eval task was received from the 296 // server in the EligibilityEvalTaskResponse. 297 bool eligibility_eval_enabled_ = false; 298 // `nullptr` if the feature is disabled. 299 cache::ResourceCache* resource_cache_; 300 }; 301 302 } // namespace http 303 } // namespace client 304 } // namespace fcp 305 306 #endif // FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_ 307