xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_federated_protocol.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/grpc_federated_protocol.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <variant>
25 
26 #include "google/protobuf/duration.pb.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "absl/time/time.h"
30 #include "absl/types/span.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/base/time_util.h"
33 #include "fcp/client/diag_codes.pb.h"
34 #include "fcp/client/engine/engine.pb.h"
35 #include "fcp/client/event_publisher.h"
36 #include "fcp/client/federated_protocol.h"
37 #include "fcp/client/federated_protocol_util.h"
38 #include "fcp/client/fl_runner.pb.h"
39 #include "fcp/client/flags.h"
40 #include "fcp/client/grpc_bidi_stream.h"
41 #include "fcp/client/http/http_client.h"
42 #include "fcp/client/http/in_memory_request_response.h"
43 #include "fcp/client/interruptible_runner.h"
44 #include "fcp/client/log_manager.h"
45 #include "fcp/client/opstats/opstats_logger.h"
46 #include "fcp/client/secagg_event_publisher.h"
47 #include "fcp/client/secagg_runner.h"
48 #include "fcp/client/stats.h"
49 #include "fcp/protos/federated_api.pb.h"
50 #include "fcp/protos/plan.pb.h"
51 #include "fcp/secagg/client/secagg_client.h"
52 #include "fcp/secagg/client/send_to_server_interface.h"
53 #include "fcp/secagg/client/state_transition_listener_interface.h"
54 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
55 #include "fcp/secagg/shared/crypto_rand_prng.h"
56 #include "fcp/secagg/shared/input_vector_specification.h"
57 #include "fcp/secagg/shared/math.h"
58 #include "fcp/secagg/shared/secagg_messages.pb.h"
59 #include "fcp/secagg/shared/secagg_vector.h"
60 
61 namespace fcp {
62 namespace client {
63 
64 using ::fcp::client::http::UriOrInlineData;
65 using ::fcp::secagg::ClientToServerWrapperMessage;
66 using ::google::internal::federatedml::v2::CheckinRequest;
67 using ::google::internal::federatedml::v2::CheckinRequestAck;
68 using ::google::internal::federatedml::v2::CheckinResponse;
69 using ::google::internal::federatedml::v2::ClientExecutionStats;
70 using ::google::internal::federatedml::v2::ClientStreamMessage;
71 using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
72 using ::google::internal::federatedml::v2::EligibilityEvalCheckinResponse;
73 using ::google::internal::federatedml::v2::EligibilityEvalPayload;
74 using ::google::internal::federatedml::v2::HttpCompressionFormat;
75 using ::google::internal::federatedml::v2::ProtocolOptionsRequest;
76 using ::google::internal::federatedml::v2::RetryWindow;
77 using ::google::internal::federatedml::v2::ServerStreamMessage;
78 using ::google::internal::federatedml::v2::SideChannelExecutionInfo;
79 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
80 
81 // A note on error handling:
82 //
83 // The implementation here makes a distinction between what we call 'transient'
84 // and 'permanent' errors. While the exact categorization of transient vs.
85 // permanent errors is defined by a flag, the intent is that transient errors
86 // are those types of errors that may occur in the regular course of business,
87 // e.g. due to an interrupted network connection, a load balancer temporarily
88 // rejecting our request etc. Generally, these are expected to be resolvable by
89 // merely retrying the request at a slightly later time. Permanent errors are
90 // intended to be those that are not expected to be resolvable as quickly or by
91 // merely retrying the request. E.g. if a client checks in to the server with a
92 // population name that doesn't exist, then the server may return NOT_FOUND, and
93 // until the server-side configuration is changed, it will continue returning
94 // such an error. Hence, such errors can warrant a longer retry period (to waste
95 // less of both the client's and server's resources).
96 //
97 // The errors also differ in how they interact with the server-specified retry
98 // windows that are returned via the CheckinRequestAck message.
99 // - If a permanent error occurs, then we will always return a retry window
100 //   based on the target 'permanent errors retry period' flag, regardless of
101 //   whether we received a CheckinRequestAck from the server at an earlier time.
102 // - If a transient error occurs, then we will only return a retry window
103 //   based on the target 'transient errors retry period' flag if the server
104 //   didn't already return a CheckinRequestAck. If it did return such an ack,
105 //   then one of the retry windows in that message will be used instead.
106 //
107 // Finally, note that for simplicity's sake we generally check whether a
108 // permanent error was received at the level of this class's public method,
109 // rather than deeper down in each of our helper methods that actually call
110 // directly into the gRPC stack. This keeps our state-managing code simpler, but
111 // does mean that if any of our helper methods like SendCheckinRequest produce a
112 // permanent error code locally (i.e. without it being sent by the server), it
113 // will be treated as if the server sent it and the permanent error retry period
114 // will be used. We consider this a reasonable tradeoff.
115 
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,const int64_t grpc_channel_deadline_seconds,cache::ResourceCache * resource_cache)116 GrpcFederatedProtocol::GrpcFederatedProtocol(
117     EventPublisher* event_publisher, LogManager* log_manager,
118     std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
119     const Flags* flags, ::fcp::client::http::HttpClient* http_client,
120     const std::string& federated_service_uri, const std::string& api_key,
121     const std::string& test_cert_path, absl::string_view population_name,
122     absl::string_view retry_token, absl::string_view client_version,
123     absl::string_view attestation_measurement,
124     std::function<bool()> should_abort,
125     const InterruptibleRunner::TimingConfig& timing_config,
126     const int64_t grpc_channel_deadline_seconds,
127     cache::ResourceCache* resource_cache)
128     : GrpcFederatedProtocol(
129           event_publisher, log_manager, std::move(secagg_runner_factory), flags,
130           http_client,
131           std::make_unique<GrpcBidiStream>(
132               federated_service_uri, api_key, std::string(population_name),
133               grpc_channel_deadline_seconds, test_cert_path),
134           population_name, retry_token, client_version, attestation_measurement,
135           should_abort, absl::BitGen(), timing_config, resource_cache) {}
136 
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,absl::BitGen bit_gen,const InterruptibleRunner::TimingConfig & timing_config,cache::ResourceCache * resource_cache)137 GrpcFederatedProtocol::GrpcFederatedProtocol(
138     EventPublisher* event_publisher, LogManager* log_manager,
139     std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
140     const Flags* flags, ::fcp::client::http::HttpClient* http_client,
141     std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
142     absl::string_view population_name, absl::string_view retry_token,
143     absl::string_view client_version, absl::string_view attestation_measurement,
144     std::function<bool()> should_abort, absl::BitGen bit_gen,
145     const InterruptibleRunner::TimingConfig& timing_config,
146     cache::ResourceCache* resource_cache)
147     : object_state_(ObjectState::kInitialized),
148       event_publisher_(event_publisher),
149       log_manager_(log_manager),
150       secagg_runner_factory_(std::move(secagg_runner_factory)),
151       flags_(flags),
152       http_client_(http_client),
153       grpc_bidi_stream_(std::move(grpc_bidi_stream)),
154       population_name_(population_name),
155       retry_token_(retry_token),
156       client_version_(client_version),
157       attestation_measurement_(attestation_measurement),
158       bit_gen_(std::move(bit_gen)),
159       resource_cache_(resource_cache) {
160   interruptible_runner_ = std::make_unique<InterruptibleRunner>(
161       log_manager, should_abort, timing_config,
162       InterruptibleRunner::DiagnosticsConfig{
163           .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC,
164           .interrupt_timeout =
165               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC_TIMED_OUT,
166           .interrupted_extended = ProdDiagCode::
167               BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_COMPLETED,
168           .interrupt_timeout_extended = ProdDiagCode::
169               BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_TIMED_OUT});
170   // Note that we could cast the provided error codes to absl::StatusCode
171   // values here. However, that means we'd have to handle the case when
172   // invalid integers that don't map to a StatusCode enum are provided in the
173   // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
174   // compare them with the flag-provided list of codes, which means we never
175   // have to worry about invalid flag values (besides the fact that invalid
176   // values will be silently ignored, which could make it harder to realize when
177   // flag is misconfigured).
178   const std::vector<int32_t>& error_codes =
179       flags->federated_training_permanent_error_codes();
180   federated_training_permanent_error_codes_ =
181       absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
182 }
183 
~GrpcFederatedProtocol()184 GrpcFederatedProtocol::~GrpcFederatedProtocol() { grpc_bidi_stream_->Close(); }
185 
Send(google::internal::federatedml::v2::ClientStreamMessage * client_stream_message)186 absl::Status GrpcFederatedProtocol::Send(
187     google::internal::federatedml::v2::ClientStreamMessage*
188         client_stream_message) {
189   // Note that this stopwatch measurement may not fully measure the time it
190   // takes to send all of the data, as it may return before all data was written
191   // to the network socket. It's the best estimate we can provide though.
192   auto started_stopwatch = network_stopwatch_->Start();
193   FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
194       [this, &client_stream_message]() {
195         return this->grpc_bidi_stream_->Send(client_stream_message);
196       },
197       [this]() { this->grpc_bidi_stream_->Close(); }));
198   return absl::OkStatus();
199 }
200 
Receive(google::internal::federatedml::v2::ServerStreamMessage * server_stream_message)201 absl::Status GrpcFederatedProtocol::Receive(
202     google::internal::federatedml::v2::ServerStreamMessage*
203         server_stream_message) {
204   // Note that this stopwatch measurement will generally include time spent
205   // waiting for the server to return a response (i.e. idle time rather than the
206   // true time it takes to send/receive data on the network). It's the best
207   // estimate we can provide though.
208   auto started_stopwatch = network_stopwatch_->Start();
209   FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
210       [this, &server_stream_message]() {
211         return grpc_bidi_stream_->Receive(server_stream_message);
212       },
213       [this]() { this->grpc_bidi_stream_->Close(); }));
214   return absl::OkStatus();
215 }
216 
CreateProtocolOptionsRequest(bool should_ack_checkin) const217 ProtocolOptionsRequest GrpcFederatedProtocol::CreateProtocolOptionsRequest(
218     bool should_ack_checkin) const {
219   ProtocolOptionsRequest request;
220   request.set_should_ack_checkin(should_ack_checkin);
221   request.set_supports_http_download(http_client_ != nullptr);
222   request.set_supports_eligibility_eval_http_download(
223       http_client_ != nullptr &&
224       flags_->enable_grpc_with_eligibility_eval_http_resource_support());
225 
226   // Note that we set this field for both eligibility eval checkin requests
227   // and regular checkin requests. Even though eligibility eval tasks do not
228   // have any aggregation phase, we still advertise the client's support for
229   // Secure Aggregation during the eligibility eval checkin phase. We do
230   // this because it doesn't hurt anything, and because letting the server
231   // know whether client supports SecAgg sooner rather than later in the
232   // protocol seems to provide maximum flexibility if the server ever were
233   // to use that information at this stage of the protocol in the future.
234   request.mutable_side_channels()
235       ->mutable_secure_aggregation()
236       ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
237     request.mutable_supported_http_compression_formats()->Add(
238         HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
239   return request;
240 }
241 
SendEligibilityEvalCheckinRequest()242 absl::Status GrpcFederatedProtocol::SendEligibilityEvalCheckinRequest() {
243   ClientStreamMessage client_stream_message;
244   EligibilityEvalCheckinRequest* eligibility_checkin_request =
245       client_stream_message.mutable_eligibility_eval_checkin_request();
246   eligibility_checkin_request->set_population_name(population_name_);
247   eligibility_checkin_request->set_retry_token(retry_token_);
248   eligibility_checkin_request->set_client_version(client_version_);
249   eligibility_checkin_request->set_attestation_measurement(
250       attestation_measurement_);
251   *eligibility_checkin_request->mutable_protocol_options_request() =
252       CreateProtocolOptionsRequest(
253           /* should_ack_checkin=*/true);
254 
255   return Send(&client_stream_message);
256 }
257 
SendCheckinRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info)258 absl::Status GrpcFederatedProtocol::SendCheckinRequest(
259     const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
260   ClientStreamMessage client_stream_message;
261   CheckinRequest* checkin_request =
262       client_stream_message.mutable_checkin_request();
263   checkin_request->set_population_name(population_name_);
264   checkin_request->set_retry_token(retry_token_);
265   checkin_request->set_client_version(client_version_);
266   checkin_request->set_attestation_measurement(attestation_measurement_);
267   *checkin_request->mutable_protocol_options_request() =
268       CreateProtocolOptionsRequest(/* should_ack_checkin=*/false);
269 
270   if (task_eligibility_info.has_value()) {
271     *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
272   }
273 
274   return Send(&client_stream_message);
275 }
276 
ReceiveCheckinRequestAck()277 absl::Status GrpcFederatedProtocol::ReceiveCheckinRequestAck() {
278   // Wait for a CheckinRequestAck.
279   ServerStreamMessage server_stream_message;
280   absl::Status receive_status = Receive(&server_stream_message);
281   if (receive_status.code() == absl::StatusCode::kNotFound) {
282     FCP_LOG(INFO) << "Server responded NOT_FOUND to checkin request, "
283                      "population name '"
284                   << population_name_ << "' is likely incorrect.";
285   }
286   FCP_RETURN_IF_ERROR(receive_status);
287 
288   if (!server_stream_message.has_checkin_request_ack()) {
289     log_manager_->LogDiag(
290         ProdDiagCode::
291             BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD);
292     return absl::UnimplementedError(
293         "Requested but did not receive CheckinRequestAck");
294   }
295   log_manager_->LogDiag(
296       ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED);
297   // Process the received CheckinRequestAck message.
298   const CheckinRequestAck& checkin_request_ack =
299       server_stream_message.checkin_request_ack();
300   if (!checkin_request_ack.has_retry_window_if_accepted() ||
301       !checkin_request_ack.has_retry_window_if_rejected()) {
302     return absl::UnimplementedError(
303         "Received CheckinRequestAck message with missing retry windows");
304   }
305   // Upon receiving the server's RetryWindows we immediately choose a concrete
306   // target timestamp to retry at. This ensures that a) clients of this class
307   // don't have to implement the logic to select a timestamp from a min/max
308   // range themselves, b) we tell clients of this class to come back at exactly
309   // a point in time the server intended us to come at (i.e. "now +
310   // server_specified_retry_period", and not a point in time that is partly
311   // determined by how long the remaining protocol interactions (e.g. training
312   // and results upload) will take (i.e. "now +
313   // duration_of_remaining_protocol_interactions +
314   // server_specified_retry_period").
315   checkin_request_ack_info_ = CheckinRequestAckInfo{
316       .retry_info_if_rejected =
317           RetryTimeAndToken{
318               PickRetryTimeFromRange(
319                   checkin_request_ack.retry_window_if_rejected().delay_min(),
320                   checkin_request_ack.retry_window_if_rejected().delay_max(),
321                   bit_gen_),
322               checkin_request_ack.retry_window_if_rejected().retry_token()},
323       .retry_info_if_accepted = RetryTimeAndToken{
324           PickRetryTimeFromRange(
325               checkin_request_ack.retry_window_if_accepted().delay_min(),
326               checkin_request_ack.retry_window_if_accepted().delay_max(),
327               bit_gen_),
328           checkin_request_ack.retry_window_if_accepted().retry_token()}};
329   return absl::OkStatus();
330 }
331 
332 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
ReceiveEligibilityEvalCheckinResponse(absl::Time start_time,std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)333 GrpcFederatedProtocol::ReceiveEligibilityEvalCheckinResponse(
334     absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
335                                payload_uris_received_callback) {
336   ServerStreamMessage server_stream_message;
337   FCP_RETURN_IF_ERROR(Receive(&server_stream_message));
338 
339   if (!server_stream_message.has_eligibility_eval_checkin_response()) {
340     return absl::UnimplementedError(
341         absl::StrCat("Bad response to EligibilityEvalCheckinRequest; Expected "
342                      "EligibilityEvalCheckinResponse but got ",
343                      server_stream_message.kind_case(), "."));
344   }
345 
346   const EligibilityEvalCheckinResponse& eligibility_checkin_response =
347       server_stream_message.eligibility_eval_checkin_response();
348   switch (eligibility_checkin_response.checkin_result_case()) {
349     case EligibilityEvalCheckinResponse::kEligibilityEvalPayload: {
350       const EligibilityEvalPayload& eligibility_eval_payload =
351           eligibility_checkin_response.eligibility_eval_payload();
352       object_state_ = ObjectState::kEligibilityEvalEnabled;
353       EligibilityEvalTask result{.execution_id =
354                                      eligibility_eval_payload.execution_id()};
355 
356       payload_uris_received_callback(result);
357 
358       PlanAndCheckpointPayloads payloads;
359       if (http_client_ == nullptr ||
360           !flags_->enable_grpc_with_eligibility_eval_http_resource_support()) {
361         result.payloads = {
362             .plan = eligibility_eval_payload.plan(),
363             .checkpoint = eligibility_eval_payload.init_checkpoint()};
364       } else {
365         // Fetch the task resources, returning any errors that may be
366         // encountered in the process.
367         FCP_ASSIGN_OR_RETURN(
368             result.payloads,
369             FetchTaskResources(
370                 {.plan =
371                      {
372                          .has_uri =
373                              eligibility_eval_payload.has_plan_resource(),
374                          .uri = eligibility_eval_payload.plan_resource().uri(),
375                          .data = eligibility_eval_payload.plan(),
376                          .client_cache_id =
377                              eligibility_eval_payload.plan_resource()
378                                  .client_cache_id(),
379                          .max_age = TimeUtil::ConvertProtoToAbslDuration(
380                              eligibility_eval_payload.plan_resource()
381                                  .max_age()),
382                      },
383                  .checkpoint = {
384                      .has_uri = eligibility_eval_payload
385                                     .has_init_checkpoint_resource(),
386                      .uri = eligibility_eval_payload.init_checkpoint_resource()
387                                 .uri(),
388                      .data = eligibility_eval_payload.init_checkpoint(),
389                      .client_cache_id =
390                          eligibility_eval_payload.init_checkpoint_resource()
391                              .client_cache_id(),
392                      .max_age = TimeUtil::ConvertProtoToAbslDuration(
393                          eligibility_eval_payload.init_checkpoint_resource()
394                              .max_age()),
395                  }}));
396       }
397       return std::move(result);
398     }
399     case EligibilityEvalCheckinResponse::kNoEligibilityEvalConfigured: {
400       // Nothing to do...
401       object_state_ = ObjectState::kEligibilityEvalDisabled;
402       return EligibilityEvalDisabled{};
403     }
404     case EligibilityEvalCheckinResponse::kRejectionInfo: {
405       object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
406       return Rejection{};
407     }
408     default:
409       return absl::UnimplementedError(
410           "Unrecognized EligibilityEvalCheckinResponse");
411   }
412 }
413 
414 absl::StatusOr<FederatedProtocol::CheckinResult>
ReceiveCheckinResponse(absl::Time start_time,std::function<void (const TaskAssignment &)> payload_uris_received_callback)415 GrpcFederatedProtocol::ReceiveCheckinResponse(
416     absl::Time start_time,
417     std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
418   ServerStreamMessage server_stream_message;
419   absl::Status receive_status = Receive(&server_stream_message);
420   FCP_RETURN_IF_ERROR(receive_status);
421 
422   if (!server_stream_message.has_checkin_response()) {
423     return absl::UnimplementedError(absl::StrCat(
424         "Bad response to CheckinRequest; Expected CheckinResponse but got ",
425         server_stream_message.kind_case(), "."));
426   }
427 
428   const CheckinResponse& checkin_response =
429       server_stream_message.checkin_response();
430 
431   execution_phase_id_ =
432       checkin_response.has_acceptance_info()
433           ? checkin_response.acceptance_info().execution_phase_id()
434           : "";
435   switch (checkin_response.checkin_result_case()) {
436     case CheckinResponse::kAcceptanceInfo: {
437       const auto& acceptance_info = checkin_response.acceptance_info();
438 
439       for (const auto& [k, v] : acceptance_info.side_channels())
440         side_channels_[k] = v;
441       side_channel_protocol_execution_info_ =
442           acceptance_info.side_channel_protocol_execution_info();
443       side_channel_protocol_options_response_ =
444           checkin_response.protocol_options_response().side_channels();
445 
446       std::optional<SecAggInfo> sec_agg_info = std::nullopt;
447       if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
448         sec_agg_info = SecAggInfo{
449             .expected_number_of_clients =
450                 side_channel_protocol_execution_info_.secure_aggregation()
451                     .expected_number_of_clients(),
452             .minimum_clients_in_server_visible_aggregate =
453                 side_channel_protocol_execution_info_.secure_aggregation()
454                     .minimum_clients_in_server_visible_aggregate()};
455       }
456 
457       TaskAssignment result{
458           .federated_select_uri_template =
459               acceptance_info.federated_select_uri_info().uri_template(),
460           .aggregation_session_id = acceptance_info.execution_phase_id(),
461           .sec_agg_info = sec_agg_info};
462 
463       payload_uris_received_callback(result);
464 
465       PlanAndCheckpointPayloads payloads;
466       if (http_client_ == nullptr) {
467         result.payloads = {.plan = acceptance_info.plan(),
468                            .checkpoint = acceptance_info.init_checkpoint()};
469       } else {
470         // Fetch the task resources, returning any errors that may be
471         // encountered in the process.
472         FCP_ASSIGN_OR_RETURN(
473             result.payloads,
474             FetchTaskResources(
475                 {.plan =
476                      {
477                          .has_uri = acceptance_info.has_plan_resource(),
478                          .uri = acceptance_info.plan_resource().uri(),
479                          .data = acceptance_info.plan(),
480                          .client_cache_id =
481                              acceptance_info.plan_resource().client_cache_id(),
482                          .max_age = TimeUtil::ConvertProtoToAbslDuration(
483                              acceptance_info.plan_resource().max_age()),
484                      },
485                  .checkpoint = {
486                      .has_uri = acceptance_info.has_init_checkpoint_resource(),
487                      .uri = acceptance_info.init_checkpoint_resource().uri(),
488                      .data = acceptance_info.init_checkpoint(),
489                      .client_cache_id =
490                          acceptance_info.init_checkpoint_resource()
491                              .client_cache_id(),
492                      .max_age = TimeUtil::ConvertProtoToAbslDuration(
493                          acceptance_info.init_checkpoint_resource().max_age()),
494                  }}));
495       }
496 
497       object_state_ = ObjectState::kCheckinAccepted;
498       return result;
499     }
500     case CheckinResponse::kRejectionInfo: {
501       object_state_ = ObjectState::kCheckinRejected;
502       return Rejection{};
503     }
504     default:
505       return absl::UnimplementedError("Unrecognized CheckinResponse");
506   }
507 }
508 
509 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)510 GrpcFederatedProtocol::EligibilityEvalCheckin(
511     std::function<void(const EligibilityEvalTask&)>
512         payload_uris_received_callback) {
513   FCP_CHECK(object_state_ == ObjectState::kInitialized)
514       << "Invalid call sequence";
515   object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
516 
517   absl::Time start_time = absl::Now();
518 
519   // Send an EligibilityEvalCheckinRequest.
520   absl::Status request_status = SendEligibilityEvalCheckinRequest();
521   // See note about how we handle 'permanent' errors at the top of this file.
522   UpdateObjectStateIfPermanentError(
523       request_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
524   FCP_RETURN_IF_ERROR(request_status);
525 
526   // Receive a CheckinRequestAck.
527   absl::Status ack_status = ReceiveCheckinRequestAck();
528   UpdateObjectStateIfPermanentError(
529       ack_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
530   FCP_RETURN_IF_ERROR(ack_status);
531 
532   // Receive + handle an EligibilityEvalCheckinResponse message, and update the
533   // object state based on the received response.
534   auto response = ReceiveEligibilityEvalCheckinResponse(
535       start_time, payload_uris_received_callback);
536   UpdateObjectStateIfPermanentError(
537       response.status(),
538       ObjectState::kEligibilityEvalCheckinFailedPermanentError);
539   return response;
540 }
541 
542 // This is not supported in gRPC federated protocol, we'll do nothing.
ReportEligibilityEvalError(absl::Status error_status)543 void GrpcFederatedProtocol::ReportEligibilityEvalError(
544     absl::Status error_status) {}
545 
Checkin(const std::optional<TaskEligibilityInfo> & task_eligibility_info,std::function<void (const TaskAssignment &)> payload_uris_received_callback)546 absl::StatusOr<FederatedProtocol::CheckinResult> GrpcFederatedProtocol::Checkin(
547     const std::optional<TaskEligibilityInfo>& task_eligibility_info,
548     std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
549   // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
550   // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result.
551   FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
552             object_state_ == ObjectState::kEligibilityEvalEnabled)
553       << "Checkin(...) called despite failed/rejected earlier "
554          "EligibilityEvalCheckin";
555   if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
556     FCP_CHECK(task_eligibility_info.has_value())
557         << "Missing TaskEligibilityInfo despite receiving prior "
558            "EligibilityEvalCheckin payload";
559   } else {
560     FCP_CHECK(!task_eligibility_info.has_value())
561         << "Received TaskEligibilityInfo despite not receiving a prior "
562            "EligibilityEvalCheckin payload";
563   }
564 
565   object_state_ = ObjectState::kCheckinFailed;
566 
567   absl::Time start_time = absl::Now();
568   // Send a CheckinRequest.
569   absl::Status request_status = SendCheckinRequest(task_eligibility_info);
570   // See note about how we handle 'permanent' errors at the top of this file.
571   UpdateObjectStateIfPermanentError(request_status,
572                                     ObjectState::kCheckinFailedPermanentError);
573   FCP_RETURN_IF_ERROR(request_status);
574 
575   // Receive + handle a CheckinResponse message, and update the object state
576   // based on the received response.
577   auto response =
578       ReceiveCheckinResponse(start_time, payload_uris_received_callback);
579   UpdateObjectStateIfPermanentError(response.status(),
580                                     ObjectState::kCheckinFailedPermanentError);
581   return response;
582 }
583 
584 absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)585 GrpcFederatedProtocol::PerformMultipleTaskAssignments(
586     const std::vector<std::string>& task_names) {
587   return absl::UnimplementedError(
588       "PerformMultipleTaskAssignments is not supported by "
589       "GrpcFederatedProtocol.");
590 }
591 
ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)592 absl::Status GrpcFederatedProtocol::ReportCompleted(
593     ComputationResults results, absl::Duration plan_duration,
594     std::optional<std::string> aggregation_session_id) {
595   FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
596   FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
597       << "Invalid call sequence";
598   object_state_ = ObjectState::kReportCalled;
599   auto response = Report(std::move(results), engine::COMPLETED, plan_duration);
600   // See note about how we handle 'permanent' errors at the top of this file.
601   UpdateObjectStateIfPermanentError(response,
602                                     ObjectState::kReportFailedPermanentError);
603   return response;
604 }
605 
ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_Id)606 absl::Status GrpcFederatedProtocol::ReportNotCompleted(
607     engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
608     std::optional<std::string> aggregation_session_Id) {
609   FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
610   FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
611       << "Invalid call sequence";
612   object_state_ = ObjectState::kReportCalled;
613   ComputationResults results;
614   results.emplace("tensorflow_checkpoint", "");
615   auto response = Report(std::move(results), phase_outcome, plan_duration);
616   // See note about how we handle 'permanent' errors at the top of this file.
617   UpdateObjectStateIfPermanentError(response,
618                                     ObjectState::kReportFailedPermanentError);
619   return response;
620 }
621 
622 class GrpcSecAggSendToServerImpl : public SecAggSendToServerBase {
623  public:
GrpcSecAggSendToServerImpl(GrpcBidiStreamInterface * grpc_bidi_stream,const std::function<absl::Status (ClientToServerWrapperMessage *)> & report_func)624   GrpcSecAggSendToServerImpl(
625       GrpcBidiStreamInterface* grpc_bidi_stream,
626       const std::function<absl::Status(ClientToServerWrapperMessage*)>&
627           report_func)
628       : grpc_bidi_stream_(grpc_bidi_stream), report_func_(report_func) {}
629   ~GrpcSecAggSendToServerImpl() override = default;
630 
Send(ClientToServerWrapperMessage * message)631   void Send(ClientToServerWrapperMessage* message) override {
632     // The commit message (MaskedInputRequest) must be piggy-backed onto the
633     // ReportRequest message, the logic for which is encapsulated in
634     // report_func_ so that it may be held in common between both accumulation
635     // methods.
636     if (message->message_content_case() ==
637         ClientToServerWrapperMessage::MessageContentCase::
638             kMaskedInputResponse) {
639       auto status = report_func_(message);
640       if (!status.ok())
641         FCP_LOG(ERROR) << "Could not send ReportRequest: " << status;
642       return;
643     }
644     ClientStreamMessage client_stream_message;
645     client_stream_message.mutable_secure_aggregation_client_message()->Swap(
646         message);
647     auto bytes_to_upload = client_stream_message.ByteSizeLong();
648     auto status = grpc_bidi_stream_->Send(&client_stream_message);
649     if (status.ok()) {
650       last_sent_message_size_ = bytes_to_upload;
651     }
652   }
653 
654  private:
655   GrpcBidiStreamInterface* grpc_bidi_stream_;
656   // SecAgg's output must be wrapped in a ReportRequest; because the report
657   // logic is mostly generic, this lambda allows it to be shared between
658   // aggregation types.
659   const std::function<absl::Status(ClientToServerWrapperMessage*)>&
660       report_func_;
661 };
662 
663 class GrpcSecAggProtocolDelegate : public SecAggProtocolDelegate {
664  public:
GrpcSecAggProtocolDelegate(absl::flat_hash_map<std::string,SideChannelExecutionInfo> side_channels,GrpcBidiStreamInterface * grpc_bidi_stream)665   GrpcSecAggProtocolDelegate(
666       absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels,
667       GrpcBidiStreamInterface* grpc_bidi_stream)
668       : side_channels_(std::move(side_channels)),
669         grpc_bidi_stream_(grpc_bidi_stream) {}
670 
GetModulus(const std::string & key)671   absl::StatusOr<uint64_t> GetModulus(const std::string& key) override {
672     auto execution_info = side_channels_.find(key);
673     if (execution_info == side_channels_.end())
674       return absl::InternalError(
675           absl::StrCat("Execution not found for aggregand: ", key));
676     uint64_t modulus;
677     auto secure_aggregand = execution_info->second.secure_aggregand();
678     // TODO(team): Delete output_bitwidth support once
679     // modulus is fully rolled out.
680     if (secure_aggregand.modulus() > 0) {
681       modulus = secure_aggregand.modulus();
682     } else {
683       // Note: we ignore vector.get_bitwidth() here, because (1)
684       // it is only an upper bound on the *input* bitwidth,
685       // based on the Tensorflow dtype, but (2) we have exact
686       // *output* bitwidth information from the execution_info,
687       // and that is what SecAgg needs.
688       modulus = 1ULL << secure_aggregand.output_bitwidth();
689     }
690     return modulus;
691   }
692 
ReceiveServerMessage()693   absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
694       override {
695     ServerStreamMessage server_stream_message;
696     absl::Status receive_status =
697         grpc_bidi_stream_->Receive(&server_stream_message);
698     if (!receive_status.ok()) {
699       return absl::Status(receive_status.code(),
700                           absl::StrCat("Error during SecAgg receive: ",
701                                        receive_status.message()));
702     }
703     last_received_message_size_ = server_stream_message.ByteSizeLong();
704     if (!server_stream_message.has_secure_aggregation_server_message()) {
705       return absl::InternalError(
706           absl::StrCat("Bad response to SecAgg protocol; Expected "
707                        "ServerToClientWrapperMessage but got ",
708                        server_stream_message.kind_case(), "."));
709     }
710     return server_stream_message.secure_aggregation_server_message();
711   }
712 
Abort()713   void Abort() override { grpc_bidi_stream_->Close(); }
last_received_message_size()714   size_t last_received_message_size() override {
715     return last_received_message_size_;
716   };
717 
718  private:
719   absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels_;
720   GrpcBidiStreamInterface* grpc_bidi_stream_;
721   size_t last_received_message_size_;
722 };
723 
ReportInternal(std::string tf_checkpoint,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,ClientToServerWrapperMessage * secagg_commit_message)724 absl::Status GrpcFederatedProtocol::ReportInternal(
725     std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
726     absl::Duration plan_duration,
727     ClientToServerWrapperMessage* secagg_commit_message) {
728   ClientStreamMessage client_stream_message;
729   auto report_request = client_stream_message.mutable_report_request();
730   report_request->set_population_name(population_name_);
731   report_request->set_execution_phase_id(execution_phase_id_);
732   auto report = report_request->mutable_report();
733 
734   // 1. Include TF checkpoint and/or SecAgg commit message.
735   report->set_update_checkpoint(std::move(tf_checkpoint));
736   if (secagg_commit_message) {
737     client_stream_message.mutable_secure_aggregation_client_message()->Swap(
738         secagg_commit_message);
739   }
740 
741   // 2. Include outcome of computation.
742   report->set_status_code(phase_outcome == engine::COMPLETED
743                               ? google::rpc::OK
744                               : google::rpc::INTERNAL);
745 
746   // 3. Include client execution statistics, if any.
747   ClientExecutionStats client_execution_stats;
748   client_execution_stats.mutable_duration()->set_seconds(
749       absl::IDivDuration(plan_duration, absl::Seconds(1), &plan_duration));
750   client_execution_stats.mutable_duration()->set_nanos(static_cast<int32_t>(
751       absl::IDivDuration(plan_duration, absl::Nanoseconds(1), &plan_duration)));
752   report->add_serialized_train_event()->PackFrom(client_execution_stats);
753 
754   // 4. Send ReportRequest.
755 
756   // Note that we do not use the GrpcFederatedProtocol::Send(...) helper method
757   // here, since we are already running within a call to
758   // InterruptibleRunner::Run.
759   const auto status = this->grpc_bidi_stream_->Send(&client_stream_message);
760   if (!status.ok()) {
761     return absl::Status(
762         status.code(),
763         absl::StrCat("Error sending ReportRequest: ", status.message()));
764   }
765 
766   return absl::OkStatus();
767 }
768 
Report(ComputationResults results,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration)769 absl::Status GrpcFederatedProtocol::Report(ComputationResults results,
770                                            engine::PhaseOutcome phase_outcome,
771                                            absl::Duration plan_duration) {
772   std::string tf_checkpoint;
773   bool has_checkpoint;
774   for (auto& [k, v] : results) {
775     if (std::holds_alternative<TFCheckpoint>(v)) {
776       tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
777       has_checkpoint = true;
778       break;
779     }
780   }
781 
782   // This lambda allows for convenient reporting from within SecAgg's
783   // SendToServerInterface::Send().
784   std::function<absl::Status(ClientToServerWrapperMessage*)> report_lambda =
785       [&](ClientToServerWrapperMessage* secagg_commit_message) -> absl::Status {
786     return ReportInternal(std::move(tf_checkpoint), phase_outcome,
787                           plan_duration, secagg_commit_message);
788   };
789 
790   // Run the Secure Aggregation protocol, if necessary.
791   if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
792     auto secure_aggregation_protocol_execution_info =
793         side_channel_protocol_execution_info_.secure_aggregation();
794     auto expected_number_of_clients =
795         secure_aggregation_protocol_execution_info.expected_number_of_clients();
796 
797     FCP_LOG(INFO) << "Reporting via Secure Aggregation";
798     if (phase_outcome != engine::COMPLETED)
799       return absl::InternalError(
800           "Aborting the SecAgg protocol (no update was produced).");
801 
802     if (side_channel_protocol_options_response_.secure_aggregation()
803             .client_variant() != secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1) {
804       log_manager_->LogDiag(
805           ProdDiagCode::SECAGG_CLIENT_ERROR_UNSUPPORTED_VERSION);
806       return absl::InternalError(absl::StrCat(
807           "Unsupported SecAgg client variant: ",
808           side_channel_protocol_options_response_.secure_aggregation()
809               .client_variant()));
810     }
811 
812     auto send_to_server_impl = std::make_unique<GrpcSecAggSendToServerImpl>(
813         grpc_bidi_stream_.get(), report_lambda);
814     auto secagg_event_publisher = event_publisher_->secagg_event_publisher();
815     FCP_CHECK(secagg_event_publisher)
816         << "An implementation of "
817         << "SecAggEventPublisher must be provided.";
818     auto delegate = std::make_unique<GrpcSecAggProtocolDelegate>(
819         side_channels_, grpc_bidi_stream_.get());
820     std::unique_ptr<SecAggRunner> secagg_runner =
821         secagg_runner_factory_->CreateSecAggRunner(
822             std::move(send_to_server_impl), std::move(delegate),
823             secagg_event_publisher, log_manager_, interruptible_runner_.get(),
824             expected_number_of_clients,
825             secure_aggregation_protocol_execution_info
826                 .minimum_surviving_clients_for_reconstruction());
827 
828     FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
829   } else {
830     // Report without secure aggregation.
831     FCP_LOG(INFO) << "Reporting via Simple Aggregation";
832     if (results.size() != 1 || !has_checkpoint) {
833       return absl::InternalError(
834           "Simple Aggregation aggregands have unexpected format.");
835     }
836     FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
837         [&report_lambda]() { return report_lambda(nullptr); },
838         [this]() {
839           // What about event_publisher_ and log_manager_?
840           this->grpc_bidi_stream_->Close();
841         }));
842   }
843 
844   FCP_LOG(INFO) << "Finished reporting.";
845 
846   // Receive ReportResponse.
847   ServerStreamMessage server_stream_message;
848   absl::Status receive_status = Receive(&server_stream_message);
849   if (receive_status.code() == absl::StatusCode::kAborted) {
850     FCP_LOG(INFO) << "Server responded ABORTED.";
851   } else if (receive_status.code() == absl::StatusCode::kCancelled) {
852     FCP_LOG(INFO) << "Upload was cancelled by the client.";
853   }
854   if (!receive_status.ok()) {
855     return absl::Status(
856         receive_status.code(),
857         absl::StrCat("Error after ReportRequest: ", receive_status.message()));
858   }
859   if (!server_stream_message.has_report_response()) {
860     return absl::UnimplementedError(absl::StrCat(
861         "Bad response to ReportRequest; Expected REPORT_RESPONSE but got ",
862         server_stream_message.kind_case(), "."));
863   }
864   return absl::OkStatus();
865 }
866 
GetLatestRetryWindow()867 RetryWindow GrpcFederatedProtocol::GetLatestRetryWindow() {
868   // We explicitly enumerate all possible states here rather than using
869   // "default", to ensure that when new states are added later on, the author
870   // is forced to update this method and consider which is the correct
871   // RetryWindow to return.
872   switch (object_state_) {
873     case ObjectState::kCheckinAccepted:
874     case ObjectState::kReportCalled:
875       // If a client makes it past the 'checkin acceptance' stage, we use the
876       // 'accepted' RetryWindow unconditionally (unless a permanent error is
877       // encountered). This includes cases where the checkin is accepted, but
878       // the report request results in a (transient) error.
879       FCP_CHECK(checkin_request_ack_info_.has_value());
880       return GenerateRetryWindowFromRetryTimeAndToken(
881           checkin_request_ack_info_->retry_info_if_accepted);
882     case ObjectState::kEligibilityEvalCheckinRejected:
883     case ObjectState::kEligibilityEvalDisabled:
884     case ObjectState::kEligibilityEvalEnabled:
885     case ObjectState::kCheckinRejected:
886       FCP_CHECK(checkin_request_ack_info_.has_value());
887       return GenerateRetryWindowFromRetryTimeAndToken(
888           checkin_request_ack_info_->retry_info_if_rejected);
889     case ObjectState::kInitialized:
890     case ObjectState::kEligibilityEvalCheckinFailed:
891     case ObjectState::kCheckinFailed:
892       // If the flag is true, then we use the previously chosen absolute retry
893       // time instead (if available).
894       if (checkin_request_ack_info_.has_value()) {
895         // If we already received a server-provided retry window, then use it.
896         return GenerateRetryWindowFromRetryTimeAndToken(
897             checkin_request_ack_info_->retry_info_if_rejected);
898       }
899       // Otherwise, we generate a retry window using the flag-provided transient
900       // error retry period.
901       return GenerateRetryWindowFromTargetDelay(
902           absl::Seconds(
903               flags_->federated_training_transient_errors_retry_delay_secs()),
904           // NOLINTBEGIN(whitespace/line_length)
905           flags_
906               ->federated_training_transient_errors_retry_delay_jitter_percent(),
907           // NOLINTEND
908           bit_gen_);
909     case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
910     case ObjectState::kCheckinFailedPermanentError:
911     case ObjectState::kReportFailedPermanentError:
912       // If we encountered a permanent error during the eligibility eval or
913       // regular checkins, then we use the Flags-configured 'permanent error'
914       // retry period. Note that we do so regardless of whether the server had,
915       // by the time the permanent error was received, already returned a
916       // CheckinRequestAck containing a set of retry windows. See note on error
917       // handling at the top of this file.
918       return GenerateRetryWindowFromTargetDelay(
919           absl::Seconds(
920               flags_->federated_training_permanent_errors_retry_delay_secs()),
921           // NOLINTBEGIN(whitespace/line_length)
922           flags_
923               ->federated_training_permanent_errors_retry_delay_jitter_percent(),
924           // NOLINTEND
925           bit_gen_);
926     case ObjectState::kMultipleTaskAssignmentsAccepted:
927     case ObjectState::kMultipleTaskAssignmentsFailed:
928     case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
929     case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
930     case ObjectState::kReportMultipleTaskPartialError:
931       FCP_LOG(FATAL) << "Multi-task assignments is not supported by gRPC.";
932       RetryWindow retry_window;
933       return retry_window;
934   }
935 }
936 
937 // Converts the given RetryTimeAndToken to a zero-width RetryWindow (where
938 // delay_min and delay_max are set to the same value), by converting the target
939 // retry time to a delay relative to the current timestamp.
GenerateRetryWindowFromRetryTimeAndToken(const GrpcFederatedProtocol::RetryTimeAndToken & retry_info)940 RetryWindow GrpcFederatedProtocol::GenerateRetryWindowFromRetryTimeAndToken(
941     const GrpcFederatedProtocol::RetryTimeAndToken& retry_info) {
942   // Generate a RetryWindow with delay_min and delay_max both set to the same
943   // value.
944   RetryWindow retry_window =
945       GenerateRetryWindowFromRetryTime(retry_info.retry_time);
946   retry_window.set_retry_token(retry_info.retry_token);
947   return retry_window;
948 }
949 
UpdateObjectStateIfPermanentError(absl::Status status,GrpcFederatedProtocol::ObjectState permanent_error_object_state)950 void GrpcFederatedProtocol::UpdateObjectStateIfPermanentError(
951     absl::Status status,
952     GrpcFederatedProtocol::ObjectState permanent_error_object_state) {
953   if (federated_training_permanent_error_codes_.contains(
954           static_cast<int32_t>(status.code()))) {
955     object_state_ = permanent_error_object_state;
956   }
957 }
958 
959 absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
FetchTaskResources(GrpcFederatedProtocol::TaskResources task_resources)960 GrpcFederatedProtocol::FetchTaskResources(
961     GrpcFederatedProtocol::TaskResources task_resources) {
962   FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
963                        ConvertResourceToUriOrInlineData(task_resources.plan));
964   FCP_ASSIGN_OR_RETURN(
965       UriOrInlineData checkpoint_uri_or_data,
966       ConvertResourceToUriOrInlineData(task_resources.checkpoint));
967 
968   // Log a diag code if either resource is about to be downloaded via HTTP.
969   if (!plan_uri_or_data.uri().uri.empty() ||
970       !checkpoint_uri_or_data.uri().uri.empty()) {
971     log_manager_->LogDiag(
972         ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP);
973   }
974 
975   // Fetch the plan and init checkpoint resources if they need to be fetched
976   // (using the inline data instead if available).
977   absl::StatusOr<
978       std::vector<absl::StatusOr<::fcp::client::http::InMemoryHttpResponse>>>
979       resource_responses;
980   {
981     auto started_stopwatch = network_stopwatch_->Start();
982     resource_responses = ::fcp::client::http::FetchResourcesInMemory(
983         *http_client_, *interruptible_runner_,
984         {plan_uri_or_data, checkpoint_uri_or_data}, &http_bytes_downloaded_,
985         &http_bytes_uploaded_, resource_cache_);
986   }
987   if (!resource_responses.ok()) {
988     log_manager_->LogDiag(
989         ProdDiagCode::
990             HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
991     return resource_responses.status();
992   }
993   auto& plan_data_response = (*resource_responses)[0];
994   auto& checkpoint_data_response = (*resource_responses)[1];
995 
996   if (!plan_data_response.ok() || !checkpoint_data_response.ok()) {
997     log_manager_->LogDiag(
998         ProdDiagCode::
999             HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
1000   }
1001   // Note: we forward any error during the fetching of the plan/checkpoint
1002   // resources resources to the caller, which means that these error codes
1003   // will be checked against the set of 'permanent' error codes, just like the
1004   // errors in response to the protocol request are.
1005   if (!plan_data_response.ok()) {
1006     return absl::Status(plan_data_response.status().code(),
1007                         absl::StrCat("plan fetch failed: ",
1008                                      plan_data_response.status().ToString()));
1009   }
1010   if (!checkpoint_data_response.ok()) {
1011     return absl::Status(
1012         checkpoint_data_response.status().code(),
1013         absl::StrCat("checkpoint fetch failed: ",
1014                      checkpoint_data_response.status().ToString()));
1015   }
1016   if (!plan_uri_or_data.uri().uri.empty() ||
1017       !checkpoint_uri_or_data.uri().uri.empty()) {
1018     // We only want to log this diag code when we actually did fetch something
1019     // via HTTP.
1020     log_manager_->LogDiag(
1021         ProdDiagCode::
1022             HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED);
1023   }
1024 
1025   return PlanAndCheckpointPayloads{plan_data_response->body,
1026                                    checkpoint_data_response->body};
1027 }
1028 
1029 // Convert a Resource proto into a UriOrInlineData object. Returns an
1030 // `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
1031 // an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
1032 // field set.
1033 absl::StatusOr<UriOrInlineData>
ConvertResourceToUriOrInlineData(const GrpcFederatedProtocol::TaskResource & resource)1034 GrpcFederatedProtocol::ConvertResourceToUriOrInlineData(
1035     const GrpcFederatedProtocol::TaskResource& resource) {
1036   // We need to support 3 states:
1037   // - Inline data is available.
1038   // - No inline data nor is there a URI. This should be treated as there being
1039   //   an 'empty' inline data.
1040   // - No inline data is available but a URI is available.
1041   if (!resource.has_uri) {
1042     // If the URI field wasn't set, then we'll just use the inline data field
1043     // (which will either be set or be empty).
1044     //
1045     // Note: this copies the data into the new absl::Cord. However, this Cord is
1046     // then passed around all the way to fl_runner.cc without copying its data,
1047     // so this is ultimately approx. as efficient as the non-HTTP resource code
1048     // path where we also make a copy of the protobuf string into a new string
1049     // which is then returned.
1050     return UriOrInlineData::CreateInlineData(
1051         absl::Cord(resource.data),
1052         UriOrInlineData::InlineData::CompressionFormat::kUncompressed);
1053   }
1054   if (resource.uri.empty()) {
1055     return absl::InvalidArgumentError(
1056         "Resource uri must be non-empty when set");
1057   }
1058   return UriOrInlineData::CreateUri(resource.uri, resource.client_cache_id,
1059                                     resource.max_age);
1060 }
1061 
GetNetworkStats()1062 NetworkStats GrpcFederatedProtocol::GetNetworkStats() {
1063   // Note: the `HttpClient` bandwidth stats are similar to the gRPC protocol's
1064   // "chunking layer" stats, in that they reflect as closely as possible the
1065   // amount of data sent on the wire.
1066   return {.bytes_downloaded = grpc_bidi_stream_->ChunkingLayerBytesReceived() +
1067                               http_bytes_downloaded_,
1068           .bytes_uploaded = grpc_bidi_stream_->ChunkingLayerBytesSent() +
1069                             http_bytes_uploaded_,
1070           .network_duration = network_stopwatch_->GetTotalDuration()};
1071 }
1072 
1073 }  // namespace client
1074 }  // namespace fcp
1075