xref: /aosp_15_r20/external/federated-compute/fcp/client/http/http_federated_protocol.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/http/http_federated_protocol.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 // #include "google/longrunning/operations.pb.h"
29 #include "google/protobuf/any.pb.h"
30 // #include "google/rpc/code.pb.h"
31 #include "absl/container/flat_hash_set.h"
32 #include "absl/random/random.h"
33 #include "absl/status/status.h"
34 #include "absl/status/statusor.h"
35 #include "absl/strings/ascii.h"
36 #include "absl/strings/cord.h"
37 #include "absl/strings/str_cat.h"
38 #include "absl/strings/string_view.h"
39 #include "absl/strings/substitute.h"
40 #include "absl/time/time.h"
41 #include "fcp/base/clock.h"
42 #include "fcp/base/monitoring.h"
43 #include "fcp/base/time_util.h"
44 #include "fcp/base/wall_clock_stopwatch.h"
45 #include "fcp/client/diag_codes.pb.h"
46 #include "fcp/client/engine/engine.pb.h"
47 #include "fcp/client/federated_protocol.h"
48 #include "fcp/client/federated_protocol_util.h"
49 #include "fcp/client/fl_runner.pb.h"
50 #include "fcp/client/flags.h"
51 #include "fcp/client/http/http_client.h"
52 #include "fcp/client/http/http_client_util.h"
53 #include "fcp/client/http/http_secagg_send_to_server_impl.h"
54 #include "fcp/client/http/in_memory_request_response.h"
55 #include "fcp/client/interruptible_runner.h"
56 #include "fcp/client/log_manager.h"
57 #include "fcp/client/parsing_utils.h"
58 #include "fcp/client/stats.h"
59 #include "fcp/protos/federated_api.pb.h"
60 #include "fcp/protos/federatedcompute/aggregations.pb.h"
61 #include "fcp/protos/federatedcompute/common.pb.h"
62 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
63 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
64 #include "fcp/protos/federatedcompute/task_assignments.pb.h"
65 #include "fcp/protos/plan.pb.h"
66 
67 namespace fcp {
68 namespace client {
69 namespace http {
70 namespace {
71 
72 using ::fcp::client::GenerateRetryWindowFromRetryTime;
73 using ::fcp::client::GenerateRetryWindowFromTargetDelay;
74 using ::fcp::client::PickRetryTimeFromRange;
75 using ::google::internal::federatedcompute::v1::AbortAggregationRequest;
76 using ::google::internal::federatedcompute::v1::ClientStats;
77 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskRequest;
78 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskResponse;
79 using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
80 using ::google::internal::federatedcompute::v1::
81     ReportEligibilityEvalTaskResultRequest;
82 using ::google::internal::federatedcompute::v1::ReportTaskResultRequest;
83 using ::google::internal::federatedcompute::v1::Resource;
84 using ::google::internal::federatedcompute::v1::ResourceCompressionFormat;
85 using ::google::internal::federatedcompute::v1::
86     SecureAggregationProtocolExecutionInfo;
87 using ::google::internal::federatedcompute::v1::
88     StartAggregationDataUploadRequest;
89 using ::google::internal::federatedcompute::v1::
90     StartAggregationDataUploadResponse;
91 using ::google::internal::federatedcompute::v1::StartSecureAggregationRequest;
92 using ::google::internal::federatedcompute::v1::StartSecureAggregationResponse;
93 using ::google::internal::federatedcompute::v1::StartTaskAssignmentRequest;
94 using ::google::internal::federatedcompute::v1::StartTaskAssignmentResponse;
95 using ::google::internal::federatedcompute::v1::SubmitAggregationResultRequest;
96 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
97 // using ::google::longrunning::Operation;
98 
99 using CompressionFormat =
100     ::fcp::client::http::UriOrInlineData::InlineData::CompressionFormat;
101 
102 // Creates the URI suffix for a RequestEligibilityEvalTask protocol request.
CreateRequestEligibilityEvalTaskUriSuffix(absl::string_view population_name)103 absl::StatusOr<std::string> CreateRequestEligibilityEvalTaskUriSuffix(
104     absl::string_view population_name) {
105   constexpr absl::string_view kRequestEligibilityEvalTaskUriSuffix =
106       "/v1/eligibilityevaltasks/$0:request";
107   FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
108                        EncodeUriSinglePathSegment(population_name));
109   return absl::Substitute(kRequestEligibilityEvalTaskUriSuffix,
110                           encoded_population_name);
111 }
112 
113 // Creates the URI suffix for a ReportEligibilityEvalTaskResult protocol
114 // request.
CreateReportEligibilityEvalTaskResultUriSuffix(absl::string_view population_name,absl::string_view session_id)115 absl::StatusOr<std::string> CreateReportEligibilityEvalTaskResultUriSuffix(
116     absl::string_view population_name, absl::string_view session_id) {
117   constexpr absl::string_view kReportEligibilityEvalTaskResultUriSuffix =
118       "/v1/populations/$0/eligibilityevaltasks/$1:reportresult";
119   FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
120                        EncodeUriSinglePathSegment(population_name));
121   FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
122                        EncodeUriSinglePathSegment(session_id));
123   return absl::Substitute(kReportEligibilityEvalTaskResultUriSuffix,
124                           encoded_population_name, encoded_session_id);
125 }
126 
127 // Creates the URI suffix for a StartTaskAssignment protocol request.
CreateStartTaskAssignmentUriSuffix(absl::string_view population_name,absl::string_view session_id)128 absl::StatusOr<std::string> CreateStartTaskAssignmentUriSuffix(
129     absl::string_view population_name, absl::string_view session_id) {
130   constexpr absl::string_view kStartTaskAssignmentUriSuffix =
131       "/v1/populations/$0/taskassignments/$1:start";
132   FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
133                        EncodeUriSinglePathSegment(population_name));
134   FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
135                        EncodeUriSinglePathSegment(session_id));
136   return absl::Substitute(kStartTaskAssignmentUriSuffix,
137                           encoded_population_name, encoded_session_id);
138 }
139 
140 // Creates he URI suffix for a ReportTaskResult protocol request.
CreateReportTaskResultUriSuffix(absl::string_view population_name,absl::string_view session_id)141 absl::StatusOr<std::string> CreateReportTaskResultUriSuffix(
142     absl::string_view population_name, absl::string_view session_id) {
143   constexpr absl::string_view pattern =
144       "/v1/populations/$0/taskassignments/$1:reportresult";
145   FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
146                        EncodeUriSinglePathSegment(population_name));
147   FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
148                        EncodeUriSinglePathSegment(session_id));
149   // Construct the URI suffix.
150   return absl::Substitute(pattern, encoded_population_name, encoded_session_id);
151 }
152 
CreateStartAggregationDataUploadUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)153 absl::StatusOr<std::string> CreateStartAggregationDataUploadUriSuffix(
154     absl::string_view aggregation_id, absl::string_view client_token) {
155   constexpr absl::string_view pattern =
156       "/v1/aggregations/$0/clients/$1:startdataupload";
157   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
158                        EncodeUriSinglePathSegment(aggregation_id));
159   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
160                        EncodeUriSinglePathSegment(client_token));
161   // Construct the URI suffix.
162   return absl::Substitute(pattern, encoded_aggregation_id,
163                           encoded_client_token);
164 }
165 
CreateSubmitAggregationResultUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)166 absl::StatusOr<std::string> CreateSubmitAggregationResultUriSuffix(
167     absl::string_view aggregation_id, absl::string_view client_token) {
168   constexpr absl::string_view pattern = "/v1/aggregations/$0/clients/$1:submit";
169   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
170                        EncodeUriSinglePathSegment(aggregation_id));
171   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
172                        EncodeUriSinglePathSegment(client_token));
173   // Construct the URI suffix.
174   return absl::Substitute(pattern, encoded_aggregation_id,
175                           encoded_client_token);
176 }
177 
CreateAbortAggregationUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)178 absl::StatusOr<std::string> CreateAbortAggregationUriSuffix(
179     absl::string_view aggregation_id, absl::string_view client_token) {
180   constexpr absl::string_view pattern = "/v1/aggregations/$0/clients/$1:abort";
181   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
182                        EncodeUriSinglePathSegment(aggregation_id));
183   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
184                        EncodeUriSinglePathSegment(client_token));
185   // Construct the URI suffix.
186   return absl::Substitute(pattern, encoded_aggregation_id,
187                           encoded_client_token);
188 }
189 
CreateStartSecureAggregationUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)190 absl::StatusOr<std::string> CreateStartSecureAggregationUriSuffix(
191     absl::string_view aggregation_id, absl::string_view client_token) {
192   constexpr absl::string_view pattern =
193       "/v1/secureaggregations/$0/clients/$1:start";
194   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
195                        EncodeUriSinglePathSegment(aggregation_id));
196   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
197                        EncodeUriSinglePathSegment(client_token));
198   // Construct the URI suffix.
199   return absl::Substitute(pattern, encoded_aggregation_id,
200                           encoded_client_token);
201 }
202 
203 // Convert a Resource proto into a UriOrInlineData object. Returns an
204 // `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
205 // an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
206 // field set.
ConvertResourceToUriOrInlineData(const Resource & resource)207 absl::StatusOr<UriOrInlineData> ConvertResourceToUriOrInlineData(
208     const Resource& resource) {
209   switch (resource.resource_case()) {
210     case Resource::ResourceCase::kUri:
211       if (resource.uri().empty()) {
212         return absl::InvalidArgumentError(
213             "Resource.uri must be non-empty when set");
214       }
215       return UriOrInlineData::CreateUri(
216           resource.uri(), resource.client_cache_id(),
217           TimeUtil::ConvertProtoToAbslDuration(resource.max_age()));
218     case Resource::ResourceCase::kInlineResource: {
219       CompressionFormat compression_format = CompressionFormat::kUncompressed;
220       if (resource.inline_resource().has_compression_format()) {
221         switch (resource.inline_resource().compression_format()) {
222           case ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP:
223             compression_format = CompressionFormat::kGzip;
224             break;
225           default:
226             return absl::UnimplementedError(
227                 "Unknown ResourceCompressionFormat");
228         }
229       }
230       return UriOrInlineData::CreateInlineData(
231           absl::Cord(resource.inline_resource().data()), compression_format);
232     }
233     case Resource::ResourceCase::RESOURCE_NOT_SET:
234       // If neither field is set at all, we'll just act as if we got an empty
235       // inline data field.
236       return UriOrInlineData::CreateInlineData(
237           absl::Cord(), CompressionFormat::kUncompressed);
238     default:
239       return absl::UnimplementedError("Unknown Resource type");
240   }
241 }
242 
ConvertPhaseOutcomeToRpcCode(engine::PhaseOutcome phase_outcome)243 ::google::internal::federatedcompute::v1::Code ConvertPhaseOutcomeToRpcCode(
244     engine::PhaseOutcome phase_outcome) {
245   switch (phase_outcome) {
246     case engine::PhaseOutcome::COMPLETED:
247       return ::google::internal::federatedcompute::v1::Code::OK;
248     case engine::PhaseOutcome::ERROR:
249       return ::google::internal::federatedcompute::v1::Code::INTERNAL;
250     case engine::PhaseOutcome::INTERRUPTED:
251       return ::google::internal::federatedcompute::v1::Code::CANCELLED;
252     default:
253       return ::google::internal::federatedcompute::v1::Code::UNKNOWN;
254   }
255 }
256 
CreateReportTaskResultRequest(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,absl::string_view aggregation_id,absl::string_view task_name)257 absl::StatusOr<ReportTaskResultRequest> CreateReportTaskResultRequest(
258     engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
259     absl::string_view aggregation_id, absl::string_view task_name) {
260   ReportTaskResultRequest request;
261   request.set_aggregation_id(std::string(aggregation_id));
262   request.set_task_name(std::string(task_name));
263   request.set_computation_status_code(
264       ConvertPhaseOutcomeToRpcCode(phase_outcome));
265   ClientStats* client_stats = request.mutable_client_stats();
266   *client_stats->mutable_computation_execution_duration() =
267       TimeUtil::ConvertAbslToProtoDuration(plan_duration);
268   return request;
269 }
270 
271 // Creates a special InterruptibleRunner which won't check the should_abort
272 // function until the timeout duration is passed.  This special
273 // InterruptibleRunner is used to issue Cancellation requests or Abort requests.
CreateDelayedInterruptibleRunner(LogManager * log_manager,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,absl::Time deadline)274 std::unique_ptr<InterruptibleRunner> CreateDelayedInterruptibleRunner(
275     LogManager* log_manager, std::function<bool()> should_abort,
276     const InterruptibleRunner::TimingConfig& timing_config,
277     absl::Time deadline) {
278   return std::make_unique<InterruptibleRunner>(
279       log_manager,
280       [deadline, should_abort]() {
281         return absl::Now() > deadline && should_abort();
282       },
283       timing_config,
284       InterruptibleRunner::DiagnosticsConfig{
285           .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
286           .interrupt_timeout =
287               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
288           .interrupted_extended = ProdDiagCode::
289               BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
290           .interrupt_timeout_extended = ProdDiagCode::
291               BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
292 }
293 }  // namespace
294 
HttpFederatedProtocol(Clock * clock,LogManager * log_manager,const Flags * flags,HttpClient * http_client,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,SecAggEventPublisher * secagg_event_publisher,absl::string_view entry_point_uri,absl::string_view api_key,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,absl::BitGen bit_gen,const InterruptibleRunner::TimingConfig & timing_config,cache::ResourceCache * resource_cache)295 HttpFederatedProtocol::HttpFederatedProtocol(
296     Clock* clock, LogManager* log_manager, const Flags* flags,
297     HttpClient* http_client,
298     std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
299     SecAggEventPublisher* secagg_event_publisher,
300     absl::string_view entry_point_uri, absl::string_view api_key,
301     absl::string_view population_name, absl::string_view retry_token,
302     absl::string_view client_version, absl::string_view attestation_measurement,
303     std::function<bool()> should_abort, absl::BitGen bit_gen,
304     const InterruptibleRunner::TimingConfig& timing_config,
305     cache::ResourceCache* resource_cache)
306     : object_state_(ObjectState::kInitialized),
307       clock_(*clock),
308       log_manager_(log_manager),
309       flags_(flags),
310       http_client_(http_client),
311       secagg_runner_factory_(std::move(secagg_runner_factory)),
312       secagg_event_publisher_(secagg_event_publisher),
313       interruptible_runner_(std::make_unique<InterruptibleRunner>(
314           log_manager, should_abort, timing_config,
315           InterruptibleRunner::DiagnosticsConfig{
316               .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
317               .interrupt_timeout =
318                   ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
319               .interrupted_extended = ProdDiagCode::
320                   BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
321               .interrupt_timeout_extended = ProdDiagCode::
322                   BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT})),
323       eligibility_eval_request_creator_(
324           std::make_unique<ProtocolRequestCreator>(
325               entry_point_uri, api_key, HeaderList{},
326               !flags->disable_http_request_body_compression())),
327       protocol_request_helper_(http_client, &bytes_downloaded_,
328                                &bytes_uploaded_, network_stopwatch_.get(),
329                                clock),
330       api_key_(api_key),
331       population_name_(population_name),
332       retry_token_(retry_token),
333       client_version_(client_version),
334       attestation_measurement_(attestation_measurement),
335       should_abort_(std::move(should_abort)),
336       bit_gen_(std::move(bit_gen)),
337       timing_config_(timing_config),
338       waiting_period_for_cancellation_(
339           absl::Seconds(flags->waiting_period_sec_for_cancellation())),
340       resource_cache_(resource_cache) {
341   // Note that we could cast the provided error codes to absl::StatusCode
342   // values here. However, that means we'd have to handle the case when
343   // invalid integers that don't map to a StatusCode enum are provided in the
344   // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
345   // compare them with the flag-provided list of codes, which means we never
346   // have to worry about invalid flag values (besides the fact that invalid
347   // values will be silently ignored, which could make it harder to realize when
348   // a flag is misconfigured).
349   const std::vector<int32_t>& error_codes =
350       flags->federated_training_permanent_error_codes();
351   federated_training_permanent_error_codes_ =
352       absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
353 }
354 
355 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)356 HttpFederatedProtocol::EligibilityEvalCheckin(
357     std::function<void(const EligibilityEvalTask&)>
358         payload_uris_received_callback) {
359   FCP_CHECK(object_state_ == ObjectState::kInitialized)
360       << "Invalid call sequence";
361   object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
362 
363   // Send the request and parse the response.
364   auto response = HandleEligibilityEvalTaskResponse(
365       PerformEligibilityEvalTaskRequest(), payload_uris_received_callback);
366   // Update the object state to ensure we return the correct retry delay.
367   UpdateObjectStateIfPermanentError(
368       response.status(),
369       ObjectState::kEligibilityEvalCheckinFailedPermanentError);
370   if (response.ok() && std::holds_alternative<EligibilityEvalTask>(*response)) {
371     eligibility_eval_enabled_ = true;
372   }
373   return response;
374 }
375 
376 absl::StatusOr<InMemoryHttpResponse>
PerformEligibilityEvalTaskRequest()377 HttpFederatedProtocol::PerformEligibilityEvalTaskRequest() {
378   // Create and serialize the request body. Note that the `population_name`
379   // field is set in the URI instead of in this request proto message.
380   EligibilityEvalTaskRequest request;
381   request.mutable_client_version()->set_version_code(client_version_);
382   request.mutable_attestation_measurement()->set_value(
383       attestation_measurement_);
384 
385   request.mutable_resource_capabilities()->add_supported_compression_formats(
386       ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
387   request.mutable_eligibility_eval_task_capabilities()
388       ->set_supports_multiple_task_assignment(
389           flags_->http_protocol_supports_multiple_task_assignments());
390 
391   FCP_ASSIGN_OR_RETURN(
392       std::string uri_suffix,
393       CreateRequestEligibilityEvalTaskUriSuffix(population_name_));
394   FCP_ASSIGN_OR_RETURN(
395       std::unique_ptr<HttpRequest> http_request,
396       eligibility_eval_request_creator_->CreateProtocolRequest(
397           uri_suffix, {}, HttpRequest::Method::kPost,
398           request.SerializeAsString(), /*is_protobuf_encoded=*/true));
399 
400   // Issue the request.
401   return protocol_request_helper_.PerformProtocolRequest(
402       std::move(http_request), *interruptible_runner_);
403 }
404 
405 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
HandleEligibilityEvalTaskResponse(absl::StatusOr<InMemoryHttpResponse> http_response,std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)406 HttpFederatedProtocol::HandleEligibilityEvalTaskResponse(
407     absl::StatusOr<InMemoryHttpResponse> http_response,
408     std::function<void(const EligibilityEvalTask&)>
409         payload_uris_received_callback) {
410   if (!http_response.ok()) {
411     // If the protocol request failed then forward the error, but add a prefix
412     // to the error message to ensure we can easily distinguish an HTTP error
413     // occurring in response to the protocol request from HTTP errors occurring
414     // during checkpoint/plan resource fetch requests later on.
415     return absl::Status(http_response.status().code(),
416                         absl::StrCat("protocol request failed: ",
417                                      http_response.status().ToString()));
418   }
419 
420   EligibilityEvalTaskResponse response_proto;
421   if (!response_proto.ParseFromString(std::string(http_response->body))) {
422     return absl::InvalidArgumentError("Could not parse response_proto");
423   }
424 
425   // Upon receiving the server's RetryWindows we immediately choose a concrete
426   // target timestamp to retry at. This ensures that a) clients of this class
427   // don't have to implement the logic to select a timestamp from a min/max
428   // range themselves, b) we tell clients of this class to come back at exactly
429   // a point in time the server intended us to come at (i.e. "now +
430   // server_specified_retry_period", and not a point in time that is partly
431   // determined by how long the remaining protocol interactions (e.g. training
432   // and results upload) will take (i.e. "now +
433   // duration_of_remaining_protocol_interactions +
434   // server_specified_retry_period").
435   retry_times_ = RetryTimes{
436       .retry_time_if_rejected = PickRetryTimeFromRange(
437           response_proto.retry_window_if_rejected().delay_min(),
438           response_proto.retry_window_if_rejected().delay_max(), bit_gen_),
439       .retry_time_if_accepted = PickRetryTimeFromRange(
440           response_proto.retry_window_if_accepted().delay_min(),
441           response_proto.retry_window_if_accepted().delay_max(), bit_gen_)};
442 
443   // If the request was rejected then the protocol session has ended and there's
444   // no more work for us to do.
445   if (response_proto.has_rejection_info()) {
446     object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
447     return Rejection{};
448   }
449 
450   pre_task_assignment_session_id_ = response_proto.session_id();
451 
452   FCP_ASSIGN_OR_RETURN(
453       task_assignment_request_creator_,
454       ProtocolRequestCreator::Create(
455           api_key_, response_proto.task_assignment_forwarding_info(),
456           !flags_->disable_http_request_body_compression()));
457 
458   switch (response_proto.result_case()) {
459     case EligibilityEvalTaskResponse::kEligibilityEvalTask: {
460       const auto& task = response_proto.eligibility_eval_task();
461 
462       EligibilityEvalTask result{.execution_id = task.execution_id()};
463       payload_uris_received_callback(result);
464 
465       // Fetch the task resources, returning any errors that may be encountered
466       // in the process.
467       FCP_ASSIGN_OR_RETURN(
468           result.payloads,
469           FetchTaskResources(
470               {.plan = task.plan(), .checkpoint = task.init_checkpoint()}));
471       if (task.has_population_eligibility_spec() &&
472           flags_->http_protocol_supports_multiple_task_assignments()) {
473         FCP_ASSIGN_OR_RETURN(
474             result.population_eligibility_spec,
475             FetchPopulationEligibilitySpec(task.population_eligibility_spec()));
476       }
477 
478       object_state_ = ObjectState::kEligibilityEvalEnabled;
479       return std::move(result);
480     }
481     case EligibilityEvalTaskResponse::kNoEligibilityEvalConfigured: {
482       // Nothing to do...
483       object_state_ = ObjectState::kEligibilityEvalDisabled;
484       return EligibilityEvalDisabled{};
485     }
486     default:
487       return absl::UnimplementedError(
488           "Unrecognized EligibilityEvalCheckinResponse");
489   }
490 }
491 
492 absl::StatusOr<std::unique_ptr<HttpRequest>>
CreateReportEligibilityEvalTaskResultRequest(absl::Status status)493 HttpFederatedProtocol::CreateReportEligibilityEvalTaskResultRequest(
494     absl::Status status) {
495   ReportEligibilityEvalTaskResultRequest request;
496   request.set_status_code(
497       static_cast<::google::internal::federatedcompute::v1::Code>(
498           status.code()));
499   FCP_ASSIGN_OR_RETURN(std::string uri_suffix,
500                        CreateReportEligibilityEvalTaskResultUriSuffix(
501                            population_name_, pre_task_assignment_session_id_));
502   return eligibility_eval_request_creator_->CreateProtocolRequest(
503       uri_suffix, QueryParams(), HttpRequest::Method::kPost,
504       request.SerializeAsString(),
505       /*is_protobuf_encoded=*/true);
506 }
507 
ReportEligibilityEvalError(absl::Status error_status)508 void HttpFederatedProtocol::ReportEligibilityEvalError(
509     absl::Status error_status) {
510   if (!ReportEligibilityEvalErrorInternal(error_status).ok()) {
511     log_manager_->LogDiag(
512         ProdDiagCode::HTTP_REPORT_ELIGIBILITY_EVAL_RESULT_REQUEST_FAILED);
513   }
514 }
515 
ReportEligibilityEvalErrorInternal(absl::Status error_status)516 absl::Status HttpFederatedProtocol::ReportEligibilityEvalErrorInternal(
517     absl::Status error_status) {
518   FCP_ASSIGN_OR_RETURN(
519       std::unique_ptr<HttpRequest> request,
520       CreateReportEligibilityEvalTaskResultRequest(error_status));
521   return protocol_request_helper_
522       .PerformProtocolRequest(std::move(request), *interruptible_runner_)
523       .status();
524 }
525 
Checkin(const std::optional<TaskEligibilityInfo> & task_eligibility_info,std::function<void (const TaskAssignment &)> payload_uris_received_callback)526 absl::StatusOr<FederatedProtocol::CheckinResult> HttpFederatedProtocol::Checkin(
527     const std::optional<TaskEligibilityInfo>& task_eligibility_info,
528     std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
529   // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
530   // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result. Or
531   // it must follow a PerformMultipleTaskAssignments(...) regardless of the
532   // outcome for the call.
533   FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
534             object_state_ == ObjectState::kEligibilityEvalEnabled ||
535             object_state_ == ObjectState::kMultipleTaskAssignmentsAccepted ||
536             object_state_ == ObjectState::kMultipleTaskAssignmentsFailed ||
537             object_state_ ==
538                 ObjectState::kMultipleTaskAssignmentsFailedPermanentError ||
539             object_state_ ==
540                 ObjectState::kMultipleTaskAssignmentsNoAvailableTask)
541       << "Checkin(...) called despite failed/rejected earlier "
542          "EligibilityEvalCheckin";
543   if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
544     FCP_CHECK(task_eligibility_info.has_value())
545         << "Missing TaskEligibilityInfo despite receiving prior "
546            "EligibilityEvalCheckin payload";
547   } else {
548     FCP_CHECK(!task_eligibility_info.has_value())
549         << "Received TaskEligibilityInfo despite not receiving a prior "
550            "EligibilityEvalCheckin payload";
551   }
552   object_state_ = ObjectState::kCheckinFailed;
553 
554   // Send the request and parse the response.
555   auto response = HandleTaskAssignmentOperationResponse(
556       PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
557           task_eligibility_info),
558       payload_uris_received_callback);
559 
560   // Update the object state to ensure we return the correct retry delay.
561   UpdateObjectStateIfPermanentError(response.status(),
562                                     ObjectState::kCheckinFailedPermanentError);
563   return response;
564 }
565 
566 absl::StatusOr<InMemoryHttpResponse> HttpFederatedProtocol::
PerformTaskAssignmentAndReportEligibilityEvalResultRequests(const std::optional<TaskEligibilityInfo> & task_eligibility_info)567     PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
568         const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
569   // Create and serialize the request body. Note that the `population_name`
570   // and `session_id` fields are set in the URI instead of in this request
571   // proto message.
572   StartTaskAssignmentRequest request;
573   request.mutable_client_version()->set_version_code(client_version_);
574 
575   if (task_eligibility_info.has_value()) {
576     *request.mutable_task_eligibility_info() = *task_eligibility_info;
577   }
578 
579   request.mutable_resource_capabilities()->add_supported_compression_formats(
580       ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
581 
582   std::vector<std::unique_ptr<HttpRequest>> requests;
583 
584   // Construct the URI suffix.
585   FCP_ASSIGN_OR_RETURN(std::string task_assignment_uri_suffix,
586                        CreateStartTaskAssignmentUriSuffix(
587                            population_name_, pre_task_assignment_session_id_));
588   FCP_ASSIGN_OR_RETURN(
589       std::unique_ptr<HttpRequest> task_assignment_http_request,
590       task_assignment_request_creator_->CreateProtocolRequest(
591           task_assignment_uri_suffix, {}, HttpRequest::Method::kPost,
592           request.SerializeAsString(), /*is_protobuf_encoded=*/true));
593   requests.push_back(std::move(task_assignment_http_request));
594 
595   if (eligibility_eval_enabled_) {
596     FCP_ASSIGN_OR_RETURN(
597         std::unique_ptr<HttpRequest>
598             report_eligibility_eval_result_http_request,
599         CreateReportEligibilityEvalTaskResultRequest(absl::OkStatus()));
600     requests.push_back(std::move(report_eligibility_eval_result_http_request));
601   }
602 
603   // Issue the request.
604   FCP_ASSIGN_OR_RETURN(
605       std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
606       protocol_request_helper_.PerformMultipleProtocolRequests(
607           std::move(requests), *interruptible_runner_));
608   // The responses are returned in order. The first one is for the task
609   // assignment request. The second one (optional) is for the report eligibility
610   // eval task result request.  We only care about the first one.
611   if (eligibility_eval_enabled_ && !responses[1].ok()) {
612     log_manager_->LogDiag(
613         ProdDiagCode::HTTP_REPORT_ELIGIBILITY_EVAL_RESULT_REQUEST_FAILED);
614   }
615   return responses[0];
616 }
617 
618 absl::StatusOr<FederatedProtocol::CheckinResult>
HandleTaskAssignmentOperationResponse(absl::StatusOr<InMemoryHttpResponse> http_response,std::function<void (const TaskAssignment &)> payload_uris_received_callback)619 HttpFederatedProtocol::HandleTaskAssignmentOperationResponse(
620     absl::StatusOr<InMemoryHttpResponse> http_response,
621     std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
622   // If the initial response was not successful, then return immediately, even
623   // if the result was CANCELLED, since we won't have received an operation name
624   // to issue a CancelOperationRequest with anyway.
625   FCP_RETURN_IF_ERROR(http_response);
626   StartTaskAssignmentResponse response_proto;
627   if (!response_proto.ParseFromString(std::string(http_response->body))) {
628     return absl::InvalidArgumentError(
629         "could not parse StartTaskAssignmentResponse proto");
630   }
631 
632   // absl::StatusOr<Operation> initial_operation =
633   //     ParseOperationProtoFromHttpResponse(http_response);
634   // if (!initial_operation.ok()) {
635   //   return absl::Status(initial_operation.status().code(),
636   //                       absl::StrCat("protocol request failed: ",
637   //                                    initial_operation.status().ToString()));
638   //   }
639   //   absl::StatusOr<Operation> response_operation_proto =
640   //       protocol_request_helper_.PollOperationResponseUntilDone(
641   //           *initial_operation, *task_assignment_request_creator_,
642   //           *interruptible_runner_);
643   //   if (!response_operation_proto.ok()) {
644   //     // If the protocol request failed then issue a cancellation request to
645   //     let
646   //     // the server know the operation will be abandoned, and forward the
647   //     error,
648   //     // but add a prefix to the error message to ensure we can easily
649   //     // distinguish an HTTP error occurring in response to the protocol
650   //     request
651   //     // from HTTP errors occurring during checkpoint/plan resource fetch
652   //     // requests later on.
653   //     FCP_ASSIGN_OR_RETURN(std::string operation_name,
654   //                          ExtractOperationName(*initial_operation));
655   //     // Client interruption
656   //     std::unique_ptr<InterruptibleRunner> cancellation_runner =
657   //         CreateDelayedInterruptibleRunner(
658   //             log_manager_, should_abort_, timing_config_,
659   //             absl::Now() + waiting_period_for_cancellation_);
660   //     if (!protocol_request_helper_
661   //              .CancelOperation(operation_name,
662   //                               *task_assignment_request_creator_,
663   //                               *cancellation_runner)
664   //              .ok()) {
665   //       log_manager_->LogDiag(
666   //           ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED);
667   //     }
668   //     return absl::Status(
669   //         response_operation_proto.status().code(),
670   //         absl::StrCat("protocol request failed: ",
671   //                      response_operation_proto.status().ToString()));
672   //   }
673 
674   //   // The Operation has finished. Check if it resulted in an error, and if
675   //   so
676   //   // forward it after converting it to an absl::Status error.
677   //   if (response_operation_proto->has_error()) {
678   //     auto rpc_error =
679   //         ConvertRpcStatusToAbslStatus(response_operation_proto->error());
680   //     return absl::Status(
681   //         rpc_error.code(),
682   //         absl::StrCat("Operation ", response_operation_proto->name(),
683   //                      " contained error: ", rpc_error.ToString()));
684   //   }
685 
686   // Otherwise, handle the StartTaskAssignmentResponse that should have been
687   // returned by the Operation response proto.
688   return HandleTaskAssignmentInnerResponse(response_proto,
689                                            payload_uris_received_callback);
690 }
691 
692 absl::StatusOr<FederatedProtocol::CheckinResult>
HandleTaskAssignmentInnerResponse(const StartTaskAssignmentResponse & response_proto,std::function<void (const TaskAssignment &)> payload_uris_received_callback)693 HttpFederatedProtocol::HandleTaskAssignmentInnerResponse(
694     const StartTaskAssignmentResponse& response_proto,
695     std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
696   // StartTaskAssignmentResponse response_proto;
697   // if (!operation_response.UnpackTo(&response_proto)) {
698   //   return absl::InvalidArgumentError(
699   //       "could not parse StartTaskAssignmentResponse proto");
700   // }
701   if (response_proto.has_rejection_info()) {
702     object_state_ = ObjectState::kCheckinRejected;
703     return Rejection{};
704   }
705   if (!response_proto.has_task_assignment()) {
706     return absl::UnimplementedError("Unrecognized StartTaskAssignmentResponse");
707   }
708   const auto& task_assignment = response_proto.task_assignment();
709 
710   FCP_ASSIGN_OR_RETURN(
711       default_task_info_.aggregation_request_creator,
712       ProtocolRequestCreator::Create(
713           api_key_, task_assignment.aggregation_data_forwarding_info(),
714           !flags_->disable_http_request_body_compression()));
715 
716   TaskAssignment result = {
717       .federated_select_uri_template =
718           task_assignment.federated_select_uri_info().uri_template(),
719       .aggregation_session_id = task_assignment.aggregation_id(),
720       .sec_agg_info = std::nullopt};
721   if (task_assignment.has_secure_aggregation_info()) {
722     result.sec_agg_info =
723         SecAggInfo{.minimum_clients_in_server_visible_aggregate =
724                        task_assignment.secure_aggregation_info()
725                            .minimum_clients_in_server_visible_aggregate()};
726   }
727 
728   payload_uris_received_callback(result);
729 
730   // Fetch the task resources, returning any errors that may be encountered in
731   // the process.
732   FCP_ASSIGN_OR_RETURN(
733       result.payloads,
734       FetchTaskResources({.plan = task_assignment.plan(),
735                           .checkpoint = task_assignment.init_checkpoint()}));
736 
737   object_state_ = ObjectState::kCheckinAccepted;
738   default_task_info_.state = ObjectState::kCheckinAccepted;
739   default_task_info_.session_id = task_assignment.session_id();
740   default_task_info_.aggregation_session_id = task_assignment.aggregation_id();
741   default_task_info_.aggregation_authorization_token =
742       task_assignment.authorization_token();
743   default_task_info_.task_name = task_assignment.task_name();
744 
745   return std::move(result);
746 }
747 
748 absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)749 HttpFederatedProtocol::PerformMultipleTaskAssignments(
750     const std::vector<std::string>& task_names) {
751   // PerformMultipleTaskAssignments(...) must follow an earlier call to
752   // EligibilityEvalCheckin() that resulted in a EligibilityEvalTask with
753   // PopulationEligibilitySpec.
754   FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
755             object_state_ == ObjectState::kEligibilityEvalEnabled)
756       << "PerformMultipleTaskAssignments(...) called despite failed/rejected "
757          "earlier "
758          "EligibilityEvalCheckin";
759   object_state_ = ObjectState::kMultipleTaskAssignmentsFailed;
760   return absl::UnimplementedError(
761       "PerformMultipleTaskAssignments is not yet implemented.");
762 }
763 
ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)764 absl::Status HttpFederatedProtocol::ReportCompleted(
765     ComputationResults results, absl::Duration plan_duration,
766     std::optional<std::string> aggregation_session_id) {
767   FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
768   PerTaskInfo* task_info;
769   if (aggregation_session_id.has_value()) {
770     if (!task_info_map_.contains(aggregation_session_id.value())) {
771       return absl::InvalidArgumentError("Unexpected aggregation_session_id.");
772     }
773     task_info = &task_info_map_[aggregation_session_id.value()];
774   } else {
775     task_info = &default_task_info_;
776   }
777   FCP_CHECK(task_info->state == ObjectState::kCheckinAccepted ||
778             task_info->state == ObjectState::kMultipleTaskAssignmentsAccepted)
779       << "Invalid call sequence";
780   task_info->state = ObjectState::kReportCalled;
781   auto find_secagg_tensor_lambda = [](const auto& item) {
782     return std::holds_alternative<QuantizedTensor>(item.second);
783   };
784   if (std::find_if(results.begin(), results.end(), find_secagg_tensor_lambda) ==
785       results.end()) {
786     return ReportViaSimpleAggregation(std::move(results), plan_duration,
787                                       *task_info);
788   } else {
789     return ReportViaSecureAggregation(std::move(results), plan_duration,
790                                       *task_info);
791   }
792 }
793 
ReportViaSimpleAggregation(ComputationResults results,absl::Duration plan_duration,PerTaskInfo & task_info)794 absl::Status HttpFederatedProtocol::ReportViaSimpleAggregation(
795     ComputationResults results, absl::Duration plan_duration,
796     PerTaskInfo& task_info) {
797   if (results.size() != 1 ||
798       !std::holds_alternative<TFCheckpoint>(results.begin()->second)) {
799     return absl::InternalError(
800         "Simple Aggregation aggregands have unexpected format.");
801   }
802   auto start_upload_status = HandleStartDataAggregationUploadOperationResponse(
803       PerformStartDataUploadRequestAndReportTaskResult(plan_duration,
804                                                        task_info),
805       task_info);
806   if (!start_upload_status.ok()) {
807     task_info.state = ObjectState::kReportFailedPermanentError;
808     return start_upload_status;
809   }
810   auto upload_status = UploadDataViaSimpleAgg(
811       std::get<TFCheckpoint>(std::move(results.begin()->second)), task_info);
812   if (!upload_status.ok()) {
813     task_info.state = ObjectState::kReportFailedPermanentError;
814     if (upload_status.code() != absl::StatusCode::kAborted &&
815         !AbortAggregation(upload_status,
816                           "Upload data via simple aggregation failed.",
817                           task_info)
818              .ok()) {
819       log_manager_->LogDiag(
820           ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED);
821     }
822     return upload_status;
823   }
824   return SubmitAggregationResult(task_info);
825 }
826 
827 absl::StatusOr<InMemoryHttpResponse>
PerformStartDataUploadRequestAndReportTaskResult(absl::Duration plan_duration,PerTaskInfo & task_info)828 HttpFederatedProtocol::PerformStartDataUploadRequestAndReportTaskResult(
829     absl::Duration plan_duration, PerTaskInfo& task_info) {
830   FCP_ASSIGN_OR_RETURN(
831       ReportTaskResultRequest report_task_result_request,
832       CreateReportTaskResultRequest(
833           engine::PhaseOutcome::COMPLETED, plan_duration,
834           task_info.aggregation_session_id, task_info.task_name));
835   FCP_ASSIGN_OR_RETURN(
836       std::string report_task_result_uri_suffix,
837       CreateReportTaskResultUriSuffix(population_name_, task_info.session_id));
838   FCP_ASSIGN_OR_RETURN(
839       std::unique_ptr<HttpRequest> http_report_task_result_request,
840       task_assignment_request_creator_->CreateProtocolRequest(
841           report_task_result_uri_suffix, {}, HttpRequest::Method::kPost,
842           report_task_result_request.SerializeAsString(),
843           /*is_protobuf_encoded=*/true));
844 
845   StartAggregationDataUploadRequest start_upload_request;
846   FCP_ASSIGN_OR_RETURN(std::string start_aggregation_data_upload_uri_suffix,
847                        CreateStartAggregationDataUploadUriSuffix(
848                            task_info.aggregation_session_id,
849                            task_info.aggregation_authorization_token));
850   FCP_ASSIGN_OR_RETURN(
851       std::unique_ptr<HttpRequest> http_start_aggregation_data_upload_request,
852       task_info.aggregation_request_creator->CreateProtocolRequest(
853           start_aggregation_data_upload_uri_suffix, {},
854           HttpRequest::Method::kPost, start_upload_request.SerializeAsString(),
855           /*is_protobuf_encoded=*/true));
856   FCP_LOG(INFO) << "StartAggregationDataUpload request uri is: "
857                 << http_start_aggregation_data_upload_request->uri();
858   FCP_LOG(INFO) << "ReportTaskResult request uri is: "
859                 << http_report_task_result_request->uri();
860   std::vector<std::unique_ptr<HttpRequest>> requests;
861   requests.push_back(std::move(http_start_aggregation_data_upload_request));
862   requests.push_back(std::move(http_report_task_result_request));
863   FCP_ASSIGN_OR_RETURN(
864       std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
865       protocol_request_helper_.PerformMultipleProtocolRequests(
866           std::move(requests), *interruptible_runner_));
867   // We should have two responses, otherwise we have made a developer error.
868   FCP_CHECK(responses.size() == 2);
869   // The responses are returned in order so the first response will be the one
870   // for StartAggregationDataUpload request.  We only care about this response,
871   // the ReportTaskResult request is just a best effort to report client metrics
872   // to the server, and we don't want to abort the aggregation even if it
873   // failed.
874   if (!responses[1].ok()) {
875     log_manager_->LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED);
876   }
877   return responses[0];
878 }
879 
880 absl::Status
HandleStartDataAggregationUploadOperationResponse(absl::StatusOr<InMemoryHttpResponse> http_response,PerTaskInfo & task_info)881 HttpFederatedProtocol::HandleStartDataAggregationUploadOperationResponse(
882     absl::StatusOr<InMemoryHttpResponse> http_response,
883     PerTaskInfo& task_info) {
884   // absl::StatusOr<Operation> operation =
885   //     ParseOperationProtoFromHttpResponse(http_response);
886   // if (!operation.ok()) {
887   //   // If the protocol request failed then forward the error, but add a
888   //   prefix
889   //   // to the error message to ensure we can easily distinguish an HTTP error
890   //   // occurring in response to the protocol request from HTTP errors
891   //   // occurring during upload requests later on.
892   //   return absl::Status(
893   //       operation.status().code(),
894   //       absl::StrCat(
895   //           "StartAggregationDataUpload request failed during polling: ",
896   //           operation.status().ToString()));
897   // }
898   // absl::StatusOr<Operation> response_operation_proto =
899   //     protocol_request_helper_.PollOperationResponseUntilDone(
900   //         *operation, *task_info.aggregation_request_creator,
901   //         *interruptible_runner_);
902   // if (!response_operation_proto.ok()) {
903   //   return absl::Status(
904   //       response_operation_proto.status().code(),
905   //       absl::StrCat("StartAggregationDataUpload request failed: ",
906   //                    response_operation_proto.status().ToString()));
907   // }
908 
909   // // The Operation has finished. Check if it resulted in an error, and if so
910   // // forward it after converting it to an absl::Status error.
911   // if (response_operation_proto->has_error()) {
912   //   auto rpc_error =
913   //       ConvertRpcStatusToAbslStatus(response_operation_proto->error());
914   //   return absl::Status(
915   //       rpc_error.code(),
916   //       absl::StrCat("Operation ", response_operation_proto->name(),
917   //                    " contained error: ", rpc_error.ToString()));
918   // }
919 
920   // Otherwise, handle the StartDataAggregationUploadResponse that should have
921   // been returned by the Operation response proto.
922   FCP_RETURN_IF_ERROR(http_response);
923   StartAggregationDataUploadResponse response_proto;
924   if (!response_proto.ParseFromString(std::string(http_response->body))) {
925     return absl::InvalidArgumentError(
926         "could not parse StartTaskAssignmentResponse proto");
927   }
928 
929   // Note that we reassign `aggregation_request_creator_` because from this
930   // point onwards, subsequent aggregation protocol requests should go to the
931   // endpoint identified in the aggregation_protocol_forwarding_info.
932   FCP_ASSIGN_OR_RETURN(
933       task_info.aggregation_request_creator,
934       ProtocolRequestCreator::Create(
935           api_key_, response_proto.aggregation_protocol_forwarding_info(),
936           !flags_->disable_http_request_body_compression()));
937   auto upload_resource = response_proto.resource();
938   task_info.aggregation_resource_name = upload_resource.resource_name();
939   FCP_ASSIGN_OR_RETURN(
940       task_info.data_upload_request_creator,
941       ProtocolRequestCreator::Create(
942           api_key_, upload_resource.data_upload_forwarding_info(),
943           !flags_->disable_http_request_body_compression()));
944   // TODO(team): Remove the authorization token fallback once
945   // client_token is always populated.
946   task_info.aggregation_client_token =
947       !response_proto.client_token().empty()
948           ? response_proto.client_token()
949           : task_info.aggregation_authorization_token;
950   return absl::OkStatus();
951 }
952 
UploadDataViaSimpleAgg(std::string tf_checkpoint,PerTaskInfo & task_info)953 absl::Status HttpFederatedProtocol::UploadDataViaSimpleAgg(
954     std::string tf_checkpoint, PerTaskInfo& task_info) {
955   FCP_LOG(INFO) << "Uploading checkpoint with simple aggregation.";
956   FCP_ASSIGN_OR_RETURN(
957       std::string uri_suffix,
958       CreateByteStreamUploadUriSuffix(task_info.aggregation_resource_name));
959   FCP_ASSIGN_OR_RETURN(
960       std::unique_ptr<HttpRequest> http_request,
961       task_info.data_upload_request_creator->CreateProtocolRequest(
962           uri_suffix, {{"upload_protocol", "raw"}}, HttpRequest::Method::kPost,
963           std::move(tf_checkpoint), /*is_protobuf_encoded=*/false));
964   FCP_LOG(INFO) << "ByteStream.Write request URI is: " << http_request->uri();
965   auto http_response = protocol_request_helper_.PerformProtocolRequest(
966       std::move(http_request), *interruptible_runner_);
967   if (!http_response.ok()) {
968     // If the request failed, we'll forward the error status.
969     return absl::Status(http_response.status().code(),
970                         absl::StrCat("Data upload failed: ",
971                                      http_response.status().ToString()));
972   }
973   return absl::OkStatus();
974 }
975 
SubmitAggregationResult(PerTaskInfo & task_info)976 absl::Status HttpFederatedProtocol::SubmitAggregationResult(
977     PerTaskInfo& task_info) {
978   FCP_LOG(INFO) << "Notifying the server that data upload is complete.";
979   FCP_ASSIGN_OR_RETURN(std::string uri_suffix,
980                        CreateSubmitAggregationResultUriSuffix(
981                            task_info.aggregation_session_id,
982                            task_info.aggregation_client_token));
983   SubmitAggregationResultRequest request;
984   request.set_resource_name(task_info.aggregation_resource_name);
985   FCP_ASSIGN_OR_RETURN(
986       std::unique_ptr<HttpRequest> http_request,
987       task_info.aggregation_request_creator->CreateProtocolRequest(
988           uri_suffix, {}, HttpRequest::Method::kPost,
989           request.SerializeAsString(), /*is_protobuf_encoded=*/true));
990   FCP_LOG(INFO) << "SubmitAggregationResult request URI is: "
991                 << http_request->uri();
992   auto http_response = protocol_request_helper_.PerformProtocolRequest(
993       std::move(http_request), *interruptible_runner_);
994   if (!http_response.ok()) {
995     // If the request failed, we'll forward the error status.
996     return absl::Status(http_response.status().code(),
997                         absl::StrCat("SubmitAggregationResult failed: ",
998                                      http_response.status().ToString()));
999   }
1000   return absl::OkStatus();
1001 }
1002 
AbortAggregation(absl::Status original_error_status,absl::string_view error_message_for_server,PerTaskInfo & task_info)1003 absl::Status HttpFederatedProtocol::AbortAggregation(
1004     absl::Status original_error_status,
1005     absl::string_view error_message_for_server, PerTaskInfo& task_info) {
1006   FCP_LOG(INFO) << "Aborting aggregation: " << original_error_status;
1007   FCP_CHECK(task_info.state == ObjectState::kReportFailedPermanentError)
1008       << "Invalid call sequence";
1009   FCP_ASSIGN_OR_RETURN(
1010       std::string uri_suffix,
1011       CreateAbortAggregationUriSuffix(task_info.aggregation_session_id,
1012                                       task_info.aggregation_client_token));
1013   // We only provide the server with a simplified error message.
1014   absl::Status error_status(original_error_status.code(),
1015                             error_message_for_server);
1016   AbortAggregationRequest request;
1017   *request.mutable_status() = ConvertAbslStatusToRpcStatus(error_status);
1018   FCP_ASSIGN_OR_RETURN(
1019       std::unique_ptr<HttpRequest> http_request,
1020       task_info.aggregation_request_creator->CreateProtocolRequest(
1021           uri_suffix, {}, HttpRequest::Method::kPost,
1022           request.SerializeAsString(), /*is_protobuf_encoded=*/true));
1023   std::unique_ptr<InterruptibleRunner> cancellation_runner =
1024       CreateDelayedInterruptibleRunner(
1025           log_manager_, should_abort_, timing_config_,
1026           absl::Now() + waiting_period_for_cancellation_);
1027   return protocol_request_helper_
1028       .PerformProtocolRequest(std::move(http_request), *cancellation_runner)
1029       .status();
1030 }
1031 
ReportViaSecureAggregation(ComputationResults results,absl::Duration plan_duration,PerTaskInfo & task_info)1032 absl::Status HttpFederatedProtocol::ReportViaSecureAggregation(
1033     ComputationResults results, absl::Duration plan_duration,
1034     PerTaskInfo& task_info) {
1035   FCP_ASSIGN_OR_RETURN(
1036       StartSecureAggregationResponse response_proto,
1037       StartSecureAggregationAndReportTaskResult(plan_duration, task_info));
1038   SecureAggregationProtocolExecutionInfo protocol_execution_info =
1039       response_proto.protocol_execution_info();
1040   // TODO(team): Remove the authorization token fallback once
1041   // client_token is always populated.
1042   task_info.aggregation_client_token =
1043       !response_proto.client_token().empty()
1044           ? response_proto.client_token()
1045           : task_info.aggregation_authorization_token;
1046 
1047   // Move checkpoint out of ComputationResults, and put it into a std::optional.
1048   std::optional<TFCheckpoint> tf_checkpoint;
1049   for (auto& [k, v] : results) {
1050     if (std::holds_alternative<TFCheckpoint>(v)) {
1051       tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
1052       results.erase(k);
1053       break;
1054     }
1055   }
1056   absl::StatusOr<secagg::ServerToClientWrapperMessage> server_response_holder;
1057   FCP_ASSIGN_OR_RETURN(
1058       std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
1059       HttpSecAggSendToServerImpl::Create(
1060           api_key_, &clock_, &protocol_request_helper_,
1061           interruptible_runner_.get(),
1062           [this](absl::Time deadline) {
1063             return CreateDelayedInterruptibleRunner(
1064                 this->log_manager_, this->should_abort_, this->timing_config_,
1065                 deadline);
1066           },
1067           &server_response_holder, task_info.aggregation_session_id,
1068           task_info.aggregation_client_token,
1069           response_proto.secagg_protocol_forwarding_info(),
1070           response_proto.masked_result_resource(),
1071           response_proto.nonmasked_result_resource(), std::move(tf_checkpoint),
1072           flags_->disable_http_request_body_compression(),
1073           waiting_period_for_cancellation_));
1074   auto protocol_delegate = std::make_unique<HttpSecAggProtocolDelegate>(
1075       response_proto.secure_aggregands(), &server_response_holder);
1076   auto secagg_interruptible_runner = std::make_unique<InterruptibleRunner>(
1077       log_manager_, should_abort_, timing_config_,
1078       InterruptibleRunner::DiagnosticsConfig{
1079           .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
1080           .interrupt_timeout =
1081               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
1082           .interrupted_extended = ProdDiagCode::
1083               BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
1084           .interrupt_timeout_extended = ProdDiagCode::
1085               BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
1086   std::unique_ptr<SecAggRunner> secagg_runner =
1087       secagg_runner_factory_->CreateSecAggRunner(
1088           std::move(send_to_server_impl), std::move(protocol_delegate),
1089           secagg_event_publisher_, log_manager_,
1090           secagg_interruptible_runner.get(),
1091           protocol_execution_info.expected_number_of_clients(),
1092           protocol_execution_info
1093               .minimum_surviving_clients_for_reconstruction());
1094   FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
1095   return absl::OkStatus();
1096 }
1097 
1098 absl::StatusOr<StartSecureAggregationResponse>
StartSecureAggregationAndReportTaskResult(absl::Duration plan_duration,PerTaskInfo & task_info)1099 HttpFederatedProtocol::StartSecureAggregationAndReportTaskResult(
1100     absl::Duration plan_duration, PerTaskInfo& task_info) {
1101   FCP_ASSIGN_OR_RETURN(std::string start_secure_aggregation_uri_suffix,
1102                        CreateStartSecureAggregationUriSuffix(
1103                            task_info.aggregation_session_id,
1104                            task_info.aggregation_authorization_token));
1105   FCP_ASSIGN_OR_RETURN(
1106       std::unique_ptr<HttpRequest> start_secure_aggregation_http_request,
1107       task_info.aggregation_request_creator->CreateProtocolRequest(
1108           start_secure_aggregation_uri_suffix, QueryParams(),
1109           HttpRequest::Method::kPost,
1110           StartSecureAggregationRequest::default_instance().SerializeAsString(),
1111           /*is_protobuf_encoded=*/true));
1112 
1113   FCP_ASSIGN_OR_RETURN(
1114       std::string report_task_result_uri_suffix,
1115       CreateReportTaskResultUriSuffix(population_name_, task_info.session_id));
1116   FCP_ASSIGN_OR_RETURN(
1117       ReportTaskResultRequest report_task_result_request,
1118       CreateReportTaskResultRequest(
1119           engine::PhaseOutcome::COMPLETED, plan_duration,
1120           task_info.aggregation_session_id, task_info.task_name));
1121   FCP_ASSIGN_OR_RETURN(
1122       std::unique_ptr<HttpRequest> report_task_result_http_request,
1123       task_assignment_request_creator_->CreateProtocolRequest(
1124           report_task_result_uri_suffix, QueryParams(),
1125           HttpRequest::Method::kPost,
1126           report_task_result_request.SerializeAsString(),
1127           /*is_protobuf_encoded=*/true));
1128 
1129   std::vector<std::unique_ptr<HttpRequest>> requests;
1130   requests.push_back(std::move(start_secure_aggregation_http_request));
1131   requests.push_back(std::move(report_task_result_http_request));
1132 
1133   FCP_ASSIGN_OR_RETURN(
1134       std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
1135       protocol_request_helper_.PerformMultipleProtocolRequests(
1136           std::move(requests), *interruptible_runner_));
1137   // We will handle the response for StartSecureAggregation RPC.
1138   // The ReportTaskResult RPC is for best efforts only, we will ignore the
1139   // response, only log a diagcode if it fails.
1140   FCP_CHECK(responses.size() == 2);
1141   if (!responses[1].ok()) {
1142     log_manager_->LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED);
1143   }
1144   // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
1145   //                      ParseOperationProtoFromHttpResponse(responses[0]));
1146   // FCP_ASSIGN_OR_RETURN(
1147   //     Operation completed_operation,
1148   //     protocol_request_helper_.PollOperationResponseUntilDone(
1149   //         initial_operation, *task_info.aggregation_request_creator,
1150   //         *interruptible_runner_));
1151   // // The Operation has finished. Check if it resulted in an error, and if so
1152   // // forward it after converting it to an absl::Status error.
1153   // if (completed_operation.has_error()) {
1154   //   auto rpc_error =
1155   //   ConvertRpcStatusToAbslStatus(completed_operation.error()); return
1156   //   absl::Status(
1157   //       rpc_error.code(),
1158   //       absl::StrCat("Operation ", completed_operation.name(),
1159   //                    " contained error: ", rpc_error.ToString()));
1160   // }
1161   StartSecureAggregationResponse response_proto;
1162   if (!response_proto.ParseFromString(std::string(responses[0]->body))) {
1163     return absl::InvalidArgumentError(
1164         "could not parse StartSecureAggregationResponse proto");
1165   }
1166   return response_proto;
1167 }
1168 
ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)1169 absl::Status HttpFederatedProtocol::ReportNotCompleted(
1170     engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
1171     std::optional<std::string> aggregation_session_id) {
1172   FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
1173   PerTaskInfo* task_info;
1174   if (aggregation_session_id.has_value()) {
1175     if (!task_info_map_.contains(aggregation_session_id.value())) {
1176       return absl::InvalidArgumentError("Unexpected aggregation_session_id.");
1177     }
1178     task_info = &task_info_map_[aggregation_session_id.value()];
1179   } else {
1180     task_info = &default_task_info_;
1181   }
1182   FCP_CHECK(task_info->state == ObjectState::kCheckinAccepted ||
1183             task_info->state == ObjectState::kMultipleTaskAssignmentsAccepted)
1184       << "Invalid call sequence";
1185   task_info->state = ObjectState::kReportCalled;
1186   FCP_ASSIGN_OR_RETURN(
1187       ReportTaskResultRequest request,
1188       CreateReportTaskResultRequest(phase_outcome, plan_duration,
1189                                     task_info->aggregation_session_id,
1190                                     task_info->task_name));
1191   // Construct the URI suffix.
1192   FCP_ASSIGN_OR_RETURN(
1193       std::string uri_suffix,
1194       CreateReportTaskResultUriSuffix(population_name_, task_info->session_id));
1195   FCP_ASSIGN_OR_RETURN(
1196       std::unique_ptr<HttpRequest> http_request,
1197       task_assignment_request_creator_->CreateProtocolRequest(
1198           uri_suffix, {}, HttpRequest::Method::kPost,
1199           request.SerializeAsString(), /*is_protobuf_encoded=*/true));
1200 
1201   // Issue the request.
1202   absl::StatusOr<InMemoryHttpResponse> http_response =
1203       protocol_request_helper_.PerformProtocolRequest(std::move(http_request),
1204                                                       *interruptible_runner_);
1205   if (!http_response.ok()) {
1206     // If the request failed, we'll forward the error status.
1207     return absl::Status(http_response.status().code(),
1208                         absl::StrCat("ReportTaskResult request failed: ",
1209                                      http_response.status().ToString()));
1210   }
1211   return absl::OkStatus();
1212 }
1213 
1214 ::google::internal::federatedml::v2::RetryWindow
GetLatestRetryWindow()1215 HttpFederatedProtocol::GetLatestRetryWindow() {
1216   ObjectState state = GetTheLatestStateFromAllTasks();
1217   // We explicitly enumerate all possible states here rather than using
1218   // "default", to ensure that when new states are added later on, the author
1219   // is forced to update this method and consider which is the correct
1220   // RetryWindow to return.
1221   switch (state) {
1222     case ObjectState::kCheckinAccepted:
1223     case ObjectState::kMultipleTaskAssignmentsAccepted:
1224     case ObjectState::kReportCalled:
1225       // If a client makes it past the 'checkin acceptance' stage, we use the
1226       // 'accepted' RetryWindow unconditionally (unless a permanent error is
1227       // encountered). This includes cases where the checkin is accepted, but
1228       // the report request results in a (transient) error.
1229       FCP_CHECK(retry_times_.has_value());
1230       return GenerateRetryWindowFromRetryTime(
1231           retry_times_->retry_time_if_accepted);
1232     case ObjectState::kEligibilityEvalCheckinRejected:
1233     case ObjectState::kEligibilityEvalDisabled:
1234     case ObjectState::kEligibilityEvalEnabled:
1235     case ObjectState::kCheckinRejected:
1236     case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
1237     case ObjectState::kReportMultipleTaskPartialError:
1238       FCP_CHECK(retry_times_.has_value());
1239       return GenerateRetryWindowFromRetryTime(
1240           retry_times_->retry_time_if_rejected);
1241     case ObjectState::kInitialized:
1242     case ObjectState::kEligibilityEvalCheckinFailed:
1243     case ObjectState::kCheckinFailed:
1244     case ObjectState::kMultipleTaskAssignmentsFailed:
1245       if (retry_times_.has_value()) {
1246         // If we already received a server-provided retry window, then use it.
1247         return GenerateRetryWindowFromRetryTime(
1248             retry_times_->retry_time_if_rejected);
1249       }
1250       // Otherwise, we generate a retry window using the flag-provided transient
1251       // error retry period.
1252       return GenerateRetryWindowFromTargetDelay(
1253           absl::Seconds(
1254               flags_->federated_training_transient_errors_retry_delay_secs()),
1255           // NOLINTBEGIN(whitespace/line_length)
1256           flags_
1257               ->federated_training_transient_errors_retry_delay_jitter_percent(),
1258           // NOLINTEND(whitespace/line_length)
1259           bit_gen_);
1260     case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
1261     case ObjectState::kCheckinFailedPermanentError:
1262     case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
1263     case ObjectState::kReportFailedPermanentError:
1264       // If we encountered a permanent error during the eligibility eval or
1265       // regular checkins, then we use the Flags-configured 'permanent error'
1266       // retry period. Note that we do so regardless of whether the server had,
1267       // by the time the permanent error was received, already returned a
1268       // CheckinRequestAck containing a set of retry windows. See note on error
1269       // handling at the top of this file.
1270       return GenerateRetryWindowFromTargetDelay(
1271           absl::Seconds(
1272               flags_->federated_training_permanent_errors_retry_delay_secs()),
1273           // NOLINTBEGIN(whitespace/line_length)
1274           flags_
1275               ->federated_training_permanent_errors_retry_delay_jitter_percent(),
1276           // NOLINTEND(whitespace/line_length)
1277           bit_gen_);
1278   }
1279 }
1280 
1281 absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
FetchTaskResources(HttpFederatedProtocol::TaskResources task_resources)1282 HttpFederatedProtocol::FetchTaskResources(
1283     HttpFederatedProtocol::TaskResources task_resources) {
1284   FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
1285                        ConvertResourceToUriOrInlineData(task_resources.plan));
1286   FCP_ASSIGN_OR_RETURN(
1287       UriOrInlineData checkpoint_uri_or_data,
1288       ConvertResourceToUriOrInlineData(task_resources.checkpoint));
1289 
1290   // Fetch the plan and init checkpoint resources if they need to be fetched
1291   // (using the inline data instead if available).
1292   absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
1293       resource_responses;
1294   {
1295     auto started_stopwatch = network_stopwatch_->Start();
1296     resource_responses = FetchResourcesInMemory(
1297         *http_client_, *interruptible_runner_,
1298         {plan_uri_or_data, checkpoint_uri_or_data}, &bytes_downloaded_,
1299         &bytes_uploaded_, resource_cache_);
1300   }
1301   FCP_RETURN_IF_ERROR(resource_responses);
1302   auto& plan_data_response = (*resource_responses)[0];
1303   auto& checkpoint_data_response = (*resource_responses)[1];
1304 
1305   // Note: we forward any error during the fetching of the plan/checkpoint
1306   // resources resources to the caller, which means that these error codes
1307   // will be checked against the set of 'permanent' error codes, just like the
1308   // errors in response to the protocol request are.
1309   if (!plan_data_response.ok()) {
1310     return absl::Status(plan_data_response.status().code(),
1311                         absl::StrCat("plan fetch failed: ",
1312                                      plan_data_response.status().ToString()));
1313   }
1314   if (!checkpoint_data_response.ok()) {
1315     return absl::Status(
1316         checkpoint_data_response.status().code(),
1317         absl::StrCat("checkpoint fetch failed: ",
1318                      checkpoint_data_response.status().ToString()));
1319   }
1320 
1321   return PlanAndCheckpointPayloads{plan_data_response->body,
1322                                    checkpoint_data_response->body};
1323 }
1324 
1325 absl::StatusOr<PopulationEligibilitySpec>
FetchPopulationEligibilitySpec(const Resource & population_eligibility_spec_resource)1326 HttpFederatedProtocol::FetchPopulationEligibilitySpec(
1327     const Resource& population_eligibility_spec_resource) {
1328   FCP_ASSIGN_OR_RETURN(
1329       UriOrInlineData population_eligibility_spec_uri_or_data,
1330       ConvertResourceToUriOrInlineData(population_eligibility_spec_resource));
1331 
1332   // Fetch the plan and init checkpoint resources if they need to be fetched
1333   // (using the inline data instead if available).
1334   absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
1335       resource_responses;
1336   {
1337     auto started_stopwatch = network_stopwatch_->Start();
1338     resource_responses = FetchResourcesInMemory(
1339         *http_client_, *interruptible_runner_,
1340         {population_eligibility_spec_uri_or_data}, &bytes_downloaded_,
1341         &bytes_uploaded_, resource_cache_);
1342   }
1343   FCP_RETURN_IF_ERROR(resource_responses);
1344   auto& response = (*resource_responses)[0];
1345 
1346   // Note: we forward any error during the fetching of the plan/checkpoint
1347   // resources resources to the caller, which means that these error codes
1348   // will be checked against the set of 'permanent' error codes, just like the
1349   // errors in response to the protocol request are.
1350   if (!response.ok()) {
1351     return absl::Status(
1352         response.status().code(),
1353         absl::StrCat("population eligibility spec fetch failed: ",
1354                      response.status().ToString()));
1355   }
1356   PopulationEligibilitySpec population_eligibility_spec;
1357   if (!ParseFromStringOrCord(population_eligibility_spec, response->body)) {
1358     return absl::InvalidArgumentError(
1359         "Unable to parse PopulationEligibilitySpec.");
1360   }
1361   return population_eligibility_spec;
1362 }
1363 
UpdateObjectStateIfPermanentError(absl::Status status,HttpFederatedProtocol::ObjectState permanent_error_object_state)1364 void HttpFederatedProtocol::UpdateObjectStateIfPermanentError(
1365     absl::Status status,
1366     HttpFederatedProtocol::ObjectState permanent_error_object_state) {
1367   if (federated_training_permanent_error_codes_.contains(
1368           static_cast<int32_t>(status.code()))) {
1369     object_state_ = permanent_error_object_state;
1370   }
1371 }
1372 
1373 FederatedProtocol::ObjectState
GetTheLatestStateFromAllTasks()1374 HttpFederatedProtocol::GetTheLatestStateFromAllTasks() {
1375   // If we didn't have successful check-in or multiple task assignments, we
1376   // don't have to check the per task states.
1377   if (object_state_ != ObjectState::kCheckinAccepted &&
1378       object_state_ != ObjectState::kMultipleTaskAssignmentsAccepted) {
1379     return object_state_;
1380   }
1381   if (!flags_->http_protocol_supports_multiple_task_assignments()) {
1382     return default_task_info_.state;
1383   }
1384 
1385   int32_t success_cnt = 0;
1386   int32_t permanent_failure_cnt = 0;
1387   int32_t task_cnt = 0;
1388   auto count_func = [&success_cnt, &permanent_failure_cnt](ObjectState state) {
1389     if (state == ObjectState::kReportCalled) {
1390       success_cnt++;
1391     }
1392     if (state == ObjectState::kReportFailedPermanentError) {
1393       permanent_failure_cnt++;
1394     }
1395   };
1396 
1397   if (default_task_info_.state != ObjectState::kInitialized) {
1398     task_cnt++;
1399     count_func(default_task_info_.state);
1400   }
1401 
1402   for (const auto& item : task_info_map_) {
1403     task_cnt++;
1404     count_func(item.second.state);
1405   }
1406 
1407   // If none of the tasks succeeds, assume all of them failed with permanent
1408   // error and return kReportFailedPermanentError. If all of them succeeds,
1409   // return kReportCalled. If only some of the tasks succeed, return
1410   // kReportMultipleTaskPartialError.
1411   if (permanent_failure_cnt == task_cnt) {
1412     return ObjectState::kReportFailedPermanentError;
1413   } else if (success_cnt == task_cnt) {
1414     return ObjectState::kReportCalled;
1415   } else {
1416     return ObjectState::kReportMultipleTaskPartialError;
1417   }
1418 }
1419 
GetNetworkStats()1420 NetworkStats HttpFederatedProtocol::GetNetworkStats() {
1421   return {.bytes_downloaded = bytes_downloaded_,
1422           .bytes_uploaded = bytes_uploaded_,
1423           .network_duration = network_stopwatch_->GetTotalDuration()};
1424 }
1425 
1426 }  // namespace http
1427 }  // namespace client
1428 }  // namespace fcp
1429