1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/grpc_federated_protocol.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <variant>
25
26 #include "google/protobuf/duration.pb.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "absl/time/time.h"
30 #include "absl/types/span.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/base/time_util.h"
33 #include "fcp/client/diag_codes.pb.h"
34 #include "fcp/client/engine/engine.pb.h"
35 #include "fcp/client/event_publisher.h"
36 #include "fcp/client/federated_protocol.h"
37 #include "fcp/client/federated_protocol_util.h"
38 #include "fcp/client/fl_runner.pb.h"
39 #include "fcp/client/flags.h"
40 #include "fcp/client/grpc_bidi_stream.h"
41 #include "fcp/client/http/http_client.h"
42 #include "fcp/client/http/in_memory_request_response.h"
43 #include "fcp/client/interruptible_runner.h"
44 #include "fcp/client/log_manager.h"
45 #include "fcp/client/opstats/opstats_logger.h"
46 #include "fcp/client/secagg_event_publisher.h"
47 #include "fcp/client/secagg_runner.h"
48 #include "fcp/client/stats.h"
49 #include "fcp/protos/federated_api.pb.h"
50 #include "fcp/protos/plan.pb.h"
51 #include "fcp/secagg/client/secagg_client.h"
52 #include "fcp/secagg/client/send_to_server_interface.h"
53 #include "fcp/secagg/client/state_transition_listener_interface.h"
54 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
55 #include "fcp/secagg/shared/crypto_rand_prng.h"
56 #include "fcp/secagg/shared/input_vector_specification.h"
57 #include "fcp/secagg/shared/math.h"
58 #include "fcp/secagg/shared/secagg_messages.pb.h"
59 #include "fcp/secagg/shared/secagg_vector.h"
60
61 namespace fcp {
62 namespace client {
63
64 using ::fcp::client::http::UriOrInlineData;
65 using ::fcp::secagg::ClientToServerWrapperMessage;
66 using ::google::internal::federatedml::v2::CheckinRequest;
67 using ::google::internal::federatedml::v2::CheckinRequestAck;
68 using ::google::internal::federatedml::v2::CheckinResponse;
69 using ::google::internal::federatedml::v2::ClientExecutionStats;
70 using ::google::internal::federatedml::v2::ClientStreamMessage;
71 using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
72 using ::google::internal::federatedml::v2::EligibilityEvalCheckinResponse;
73 using ::google::internal::federatedml::v2::EligibilityEvalPayload;
74 using ::google::internal::federatedml::v2::HttpCompressionFormat;
75 using ::google::internal::federatedml::v2::ProtocolOptionsRequest;
76 using ::google::internal::federatedml::v2::RetryWindow;
77 using ::google::internal::federatedml::v2::ServerStreamMessage;
78 using ::google::internal::federatedml::v2::SideChannelExecutionInfo;
79 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
80
81 // A note on error handling:
82 //
83 // The implementation here makes a distinction between what we call 'transient'
84 // and 'permanent' errors. While the exact categorization of transient vs.
85 // permanent errors is defined by a flag, the intent is that transient errors
86 // are those types of errors that may occur in the regular course of business,
87 // e.g. due to an interrupted network connection, a load balancer temporarily
88 // rejecting our request etc. Generally, these are expected to be resolvable by
89 // merely retrying the request at a slightly later time. Permanent errors are
90 // intended to be those that are not expected to be resolvable as quickly or by
91 // merely retrying the request. E.g. if a client checks in to the server with a
92 // population name that doesn't exist, then the server may return NOT_FOUND, and
93 // until the server-side configuration is changed, it will continue returning
94 // such an error. Hence, such errors can warrant a longer retry period (to waste
95 // less of both the client's and server's resources).
96 //
97 // The errors also differ in how they interact with the server-specified retry
98 // windows that are returned via the CheckinRequestAck message.
99 // - If a permanent error occurs, then we will always return a retry window
100 // based on the target 'permanent errors retry period' flag, regardless of
101 // whether we received a CheckinRequestAck from the server at an earlier time.
102 // - If a transient error occurs, then we will only return a retry window
103 // based on the target 'transient errors retry period' flag if the server
104 // didn't already return a CheckinRequestAck. If it did return such an ack,
105 // then one of the retry windows in that message will be used instead.
106 //
// Finally, note that for simplicity's sake we generally check whether a
// permanent error was received at the level of this class's public methods,
109 // rather than deeper down in each of our helper methods that actually call
110 // directly into the gRPC stack. This keeps our state-managing code simpler, but
111 // does mean that if any of our helper methods like SendCheckinRequest produce a
112 // permanent error code locally (i.e. without it being sent by the server), it
113 // will be treated as if the server sent it and the permanent error retry period
114 // will be used. We consider this a reasonable tradeoff.
115
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,const int64_t grpc_channel_deadline_seconds,cache::ResourceCache * resource_cache)116 GrpcFederatedProtocol::GrpcFederatedProtocol(
117 EventPublisher* event_publisher, LogManager* log_manager,
118 std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
119 const Flags* flags, ::fcp::client::http::HttpClient* http_client,
120 const std::string& federated_service_uri, const std::string& api_key,
121 const std::string& test_cert_path, absl::string_view population_name,
122 absl::string_view retry_token, absl::string_view client_version,
123 absl::string_view attestation_measurement,
124 std::function<bool()> should_abort,
125 const InterruptibleRunner::TimingConfig& timing_config,
126 const int64_t grpc_channel_deadline_seconds,
127 cache::ResourceCache* resource_cache)
128 : GrpcFederatedProtocol(
129 event_publisher, log_manager, std::move(secagg_runner_factory), flags,
130 http_client,
131 std::make_unique<GrpcBidiStream>(
132 federated_service_uri, api_key, std::string(population_name),
133 grpc_channel_deadline_seconds, test_cert_path),
134 population_name, retry_token, client_version, attestation_measurement,
135 should_abort, absl::BitGen(), timing_config, resource_cache) {}
136
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,absl::BitGen bit_gen,const InterruptibleRunner::TimingConfig & timing_config,cache::ResourceCache * resource_cache)137 GrpcFederatedProtocol::GrpcFederatedProtocol(
138 EventPublisher* event_publisher, LogManager* log_manager,
139 std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
140 const Flags* flags, ::fcp::client::http::HttpClient* http_client,
141 std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
142 absl::string_view population_name, absl::string_view retry_token,
143 absl::string_view client_version, absl::string_view attestation_measurement,
144 std::function<bool()> should_abort, absl::BitGen bit_gen,
145 const InterruptibleRunner::TimingConfig& timing_config,
146 cache::ResourceCache* resource_cache)
147 : object_state_(ObjectState::kInitialized),
148 event_publisher_(event_publisher),
149 log_manager_(log_manager),
150 secagg_runner_factory_(std::move(secagg_runner_factory)),
151 flags_(flags),
152 http_client_(http_client),
153 grpc_bidi_stream_(std::move(grpc_bidi_stream)),
154 population_name_(population_name),
155 retry_token_(retry_token),
156 client_version_(client_version),
157 attestation_measurement_(attestation_measurement),
158 bit_gen_(std::move(bit_gen)),
159 resource_cache_(resource_cache) {
160 interruptible_runner_ = std::make_unique<InterruptibleRunner>(
161 log_manager, should_abort, timing_config,
162 InterruptibleRunner::DiagnosticsConfig{
163 .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC,
164 .interrupt_timeout =
165 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC_TIMED_OUT,
166 .interrupted_extended = ProdDiagCode::
167 BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_COMPLETED,
168 .interrupt_timeout_extended = ProdDiagCode::
169 BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_TIMED_OUT});
170 // Note that we could cast the provided error codes to absl::StatusCode
171 // values here. However, that means we'd have to handle the case when
172 // invalid integers that don't map to a StatusCode enum are provided in the
173 // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
174 // compare them with the flag-provided list of codes, which means we never
175 // have to worry about invalid flag values (besides the fact that invalid
176 // values will be silently ignored, which could make it harder to realize when
177 // flag is misconfigured).
178 const std::vector<int32_t>& error_codes =
179 flags->federated_training_permanent_error_codes();
180 federated_training_permanent_error_codes_ =
181 absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
182 }
183
~GrpcFederatedProtocol()184 GrpcFederatedProtocol::~GrpcFederatedProtocol() { grpc_bidi_stream_->Close(); }
185
Send(google::internal::federatedml::v2::ClientStreamMessage * client_stream_message)186 absl::Status GrpcFederatedProtocol::Send(
187 google::internal::federatedml::v2::ClientStreamMessage*
188 client_stream_message) {
189 // Note that this stopwatch measurement may not fully measure the time it
190 // takes to send all of the data, as it may return before all data was written
191 // to the network socket. It's the best estimate we can provide though.
192 auto started_stopwatch = network_stopwatch_->Start();
193 FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
194 [this, &client_stream_message]() {
195 return this->grpc_bidi_stream_->Send(client_stream_message);
196 },
197 [this]() { this->grpc_bidi_stream_->Close(); }));
198 return absl::OkStatus();
199 }
200
Receive(google::internal::federatedml::v2::ServerStreamMessage * server_stream_message)201 absl::Status GrpcFederatedProtocol::Receive(
202 google::internal::federatedml::v2::ServerStreamMessage*
203 server_stream_message) {
204 // Note that this stopwatch measurement will generally include time spent
205 // waiting for the server to return a response (i.e. idle time rather than the
206 // true time it takes to send/receive data on the network). It's the best
207 // estimate we can provide though.
208 auto started_stopwatch = network_stopwatch_->Start();
209 FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
210 [this, &server_stream_message]() {
211 return grpc_bidi_stream_->Receive(server_stream_message);
212 },
213 [this]() { this->grpc_bidi_stream_->Close(); }));
214 return absl::OkStatus();
215 }
216
CreateProtocolOptionsRequest(bool should_ack_checkin) const217 ProtocolOptionsRequest GrpcFederatedProtocol::CreateProtocolOptionsRequest(
218 bool should_ack_checkin) const {
219 ProtocolOptionsRequest request;
220 request.set_should_ack_checkin(should_ack_checkin);
221 request.set_supports_http_download(http_client_ != nullptr);
222 request.set_supports_eligibility_eval_http_download(
223 http_client_ != nullptr &&
224 flags_->enable_grpc_with_eligibility_eval_http_resource_support());
225
226 // Note that we set this field for both eligibility eval checkin requests
227 // and regular checkin requests. Even though eligibility eval tasks do not
228 // have any aggregation phase, we still advertise the client's support for
229 // Secure Aggregation during the eligibility eval checkin phase. We do
230 // this because it doesn't hurt anything, and because letting the server
231 // know whether client supports SecAgg sooner rather than later in the
232 // protocol seems to provide maximum flexibility if the server ever were
233 // to use that information at this stage of the protocol in the future.
234 request.mutable_side_channels()
235 ->mutable_secure_aggregation()
236 ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
237 request.mutable_supported_http_compression_formats()->Add(
238 HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
239 return request;
240 }
241
SendEligibilityEvalCheckinRequest()242 absl::Status GrpcFederatedProtocol::SendEligibilityEvalCheckinRequest() {
243 ClientStreamMessage client_stream_message;
244 EligibilityEvalCheckinRequest* eligibility_checkin_request =
245 client_stream_message.mutable_eligibility_eval_checkin_request();
246 eligibility_checkin_request->set_population_name(population_name_);
247 eligibility_checkin_request->set_retry_token(retry_token_);
248 eligibility_checkin_request->set_client_version(client_version_);
249 eligibility_checkin_request->set_attestation_measurement(
250 attestation_measurement_);
251 *eligibility_checkin_request->mutable_protocol_options_request() =
252 CreateProtocolOptionsRequest(
253 /* should_ack_checkin=*/true);
254
255 return Send(&client_stream_message);
256 }
257
SendCheckinRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info)258 absl::Status GrpcFederatedProtocol::SendCheckinRequest(
259 const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
260 ClientStreamMessage client_stream_message;
261 CheckinRequest* checkin_request =
262 client_stream_message.mutable_checkin_request();
263 checkin_request->set_population_name(population_name_);
264 checkin_request->set_retry_token(retry_token_);
265 checkin_request->set_client_version(client_version_);
266 checkin_request->set_attestation_measurement(attestation_measurement_);
267 *checkin_request->mutable_protocol_options_request() =
268 CreateProtocolOptionsRequest(/* should_ack_checkin=*/false);
269
270 if (task_eligibility_info.has_value()) {
271 *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
272 }
273
274 return Send(&client_stream_message);
275 }
276
ReceiveCheckinRequestAck()277 absl::Status GrpcFederatedProtocol::ReceiveCheckinRequestAck() {
278 // Wait for a CheckinRequestAck.
279 ServerStreamMessage server_stream_message;
280 absl::Status receive_status = Receive(&server_stream_message);
281 if (receive_status.code() == absl::StatusCode::kNotFound) {
282 FCP_LOG(INFO) << "Server responded NOT_FOUND to checkin request, "
283 "population name '"
284 << population_name_ << "' is likely incorrect.";
285 }
286 FCP_RETURN_IF_ERROR(receive_status);
287
288 if (!server_stream_message.has_checkin_request_ack()) {
289 log_manager_->LogDiag(
290 ProdDiagCode::
291 BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD);
292 return absl::UnimplementedError(
293 "Requested but did not receive CheckinRequestAck");
294 }
295 log_manager_->LogDiag(
296 ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED);
297 // Process the received CheckinRequestAck message.
298 const CheckinRequestAck& checkin_request_ack =
299 server_stream_message.checkin_request_ack();
300 if (!checkin_request_ack.has_retry_window_if_accepted() ||
301 !checkin_request_ack.has_retry_window_if_rejected()) {
302 return absl::UnimplementedError(
303 "Received CheckinRequestAck message with missing retry windows");
304 }
305 // Upon receiving the server's RetryWindows we immediately choose a concrete
306 // target timestamp to retry at. This ensures that a) clients of this class
307 // don't have to implement the logic to select a timestamp from a min/max
308 // range themselves, b) we tell clients of this class to come back at exactly
309 // a point in time the server intended us to come at (i.e. "now +
310 // server_specified_retry_period", and not a point in time that is partly
311 // determined by how long the remaining protocol interactions (e.g. training
312 // and results upload) will take (i.e. "now +
313 // duration_of_remaining_protocol_interactions +
314 // server_specified_retry_period").
315 checkin_request_ack_info_ = CheckinRequestAckInfo{
316 .retry_info_if_rejected =
317 RetryTimeAndToken{
318 PickRetryTimeFromRange(
319 checkin_request_ack.retry_window_if_rejected().delay_min(),
320 checkin_request_ack.retry_window_if_rejected().delay_max(),
321 bit_gen_),
322 checkin_request_ack.retry_window_if_rejected().retry_token()},
323 .retry_info_if_accepted = RetryTimeAndToken{
324 PickRetryTimeFromRange(
325 checkin_request_ack.retry_window_if_accepted().delay_min(),
326 checkin_request_ack.retry_window_if_accepted().delay_max(),
327 bit_gen_),
328 checkin_request_ack.retry_window_if_accepted().retry_token()}};
329 return absl::OkStatus();
330 }
331
332 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
ReceiveEligibilityEvalCheckinResponse(absl::Time start_time,std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)333 GrpcFederatedProtocol::ReceiveEligibilityEvalCheckinResponse(
334 absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
335 payload_uris_received_callback) {
336 ServerStreamMessage server_stream_message;
337 FCP_RETURN_IF_ERROR(Receive(&server_stream_message));
338
339 if (!server_stream_message.has_eligibility_eval_checkin_response()) {
340 return absl::UnimplementedError(
341 absl::StrCat("Bad response to EligibilityEvalCheckinRequest; Expected "
342 "EligibilityEvalCheckinResponse but got ",
343 server_stream_message.kind_case(), "."));
344 }
345
346 const EligibilityEvalCheckinResponse& eligibility_checkin_response =
347 server_stream_message.eligibility_eval_checkin_response();
348 switch (eligibility_checkin_response.checkin_result_case()) {
349 case EligibilityEvalCheckinResponse::kEligibilityEvalPayload: {
350 const EligibilityEvalPayload& eligibility_eval_payload =
351 eligibility_checkin_response.eligibility_eval_payload();
352 object_state_ = ObjectState::kEligibilityEvalEnabled;
353 EligibilityEvalTask result{.execution_id =
354 eligibility_eval_payload.execution_id()};
355
356 payload_uris_received_callback(result);
357
358 PlanAndCheckpointPayloads payloads;
359 if (http_client_ == nullptr ||
360 !flags_->enable_grpc_with_eligibility_eval_http_resource_support()) {
361 result.payloads = {
362 .plan = eligibility_eval_payload.plan(),
363 .checkpoint = eligibility_eval_payload.init_checkpoint()};
364 } else {
365 // Fetch the task resources, returning any errors that may be
366 // encountered in the process.
367 FCP_ASSIGN_OR_RETURN(
368 result.payloads,
369 FetchTaskResources(
370 {.plan =
371 {
372 .has_uri =
373 eligibility_eval_payload.has_plan_resource(),
374 .uri = eligibility_eval_payload.plan_resource().uri(),
375 .data = eligibility_eval_payload.plan(),
376 .client_cache_id =
377 eligibility_eval_payload.plan_resource()
378 .client_cache_id(),
379 .max_age = TimeUtil::ConvertProtoToAbslDuration(
380 eligibility_eval_payload.plan_resource()
381 .max_age()),
382 },
383 .checkpoint = {
384 .has_uri = eligibility_eval_payload
385 .has_init_checkpoint_resource(),
386 .uri = eligibility_eval_payload.init_checkpoint_resource()
387 .uri(),
388 .data = eligibility_eval_payload.init_checkpoint(),
389 .client_cache_id =
390 eligibility_eval_payload.init_checkpoint_resource()
391 .client_cache_id(),
392 .max_age = TimeUtil::ConvertProtoToAbslDuration(
393 eligibility_eval_payload.init_checkpoint_resource()
394 .max_age()),
395 }}));
396 }
397 return std::move(result);
398 }
399 case EligibilityEvalCheckinResponse::kNoEligibilityEvalConfigured: {
400 // Nothing to do...
401 object_state_ = ObjectState::kEligibilityEvalDisabled;
402 return EligibilityEvalDisabled{};
403 }
404 case EligibilityEvalCheckinResponse::kRejectionInfo: {
405 object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
406 return Rejection{};
407 }
408 default:
409 return absl::UnimplementedError(
410 "Unrecognized EligibilityEvalCheckinResponse");
411 }
412 }
413
414 absl::StatusOr<FederatedProtocol::CheckinResult>
ReceiveCheckinResponse(absl::Time start_time,std::function<void (const TaskAssignment &)> payload_uris_received_callback)415 GrpcFederatedProtocol::ReceiveCheckinResponse(
416 absl::Time start_time,
417 std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
418 ServerStreamMessage server_stream_message;
419 absl::Status receive_status = Receive(&server_stream_message);
420 FCP_RETURN_IF_ERROR(receive_status);
421
422 if (!server_stream_message.has_checkin_response()) {
423 return absl::UnimplementedError(absl::StrCat(
424 "Bad response to CheckinRequest; Expected CheckinResponse but got ",
425 server_stream_message.kind_case(), "."));
426 }
427
428 const CheckinResponse& checkin_response =
429 server_stream_message.checkin_response();
430
431 execution_phase_id_ =
432 checkin_response.has_acceptance_info()
433 ? checkin_response.acceptance_info().execution_phase_id()
434 : "";
435 switch (checkin_response.checkin_result_case()) {
436 case CheckinResponse::kAcceptanceInfo: {
437 const auto& acceptance_info = checkin_response.acceptance_info();
438
439 for (const auto& [k, v] : acceptance_info.side_channels())
440 side_channels_[k] = v;
441 side_channel_protocol_execution_info_ =
442 acceptance_info.side_channel_protocol_execution_info();
443 side_channel_protocol_options_response_ =
444 checkin_response.protocol_options_response().side_channels();
445
446 std::optional<SecAggInfo> sec_agg_info = std::nullopt;
447 if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
448 sec_agg_info = SecAggInfo{
449 .expected_number_of_clients =
450 side_channel_protocol_execution_info_.secure_aggregation()
451 .expected_number_of_clients(),
452 .minimum_clients_in_server_visible_aggregate =
453 side_channel_protocol_execution_info_.secure_aggregation()
454 .minimum_clients_in_server_visible_aggregate()};
455 }
456
457 TaskAssignment result{
458 .federated_select_uri_template =
459 acceptance_info.federated_select_uri_info().uri_template(),
460 .aggregation_session_id = acceptance_info.execution_phase_id(),
461 .sec_agg_info = sec_agg_info};
462
463 payload_uris_received_callback(result);
464
465 PlanAndCheckpointPayloads payloads;
466 if (http_client_ == nullptr) {
467 result.payloads = {.plan = acceptance_info.plan(),
468 .checkpoint = acceptance_info.init_checkpoint()};
469 } else {
470 // Fetch the task resources, returning any errors that may be
471 // encountered in the process.
472 FCP_ASSIGN_OR_RETURN(
473 result.payloads,
474 FetchTaskResources(
475 {.plan =
476 {
477 .has_uri = acceptance_info.has_plan_resource(),
478 .uri = acceptance_info.plan_resource().uri(),
479 .data = acceptance_info.plan(),
480 .client_cache_id =
481 acceptance_info.plan_resource().client_cache_id(),
482 .max_age = TimeUtil::ConvertProtoToAbslDuration(
483 acceptance_info.plan_resource().max_age()),
484 },
485 .checkpoint = {
486 .has_uri = acceptance_info.has_init_checkpoint_resource(),
487 .uri = acceptance_info.init_checkpoint_resource().uri(),
488 .data = acceptance_info.init_checkpoint(),
489 .client_cache_id =
490 acceptance_info.init_checkpoint_resource()
491 .client_cache_id(),
492 .max_age = TimeUtil::ConvertProtoToAbslDuration(
493 acceptance_info.init_checkpoint_resource().max_age()),
494 }}));
495 }
496
497 object_state_ = ObjectState::kCheckinAccepted;
498 return result;
499 }
500 case CheckinResponse::kRejectionInfo: {
501 object_state_ = ObjectState::kCheckinRejected;
502 return Rejection{};
503 }
504 default:
505 return absl::UnimplementedError("Unrecognized CheckinResponse");
506 }
507 }
508
509 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)510 GrpcFederatedProtocol::EligibilityEvalCheckin(
511 std::function<void(const EligibilityEvalTask&)>
512 payload_uris_received_callback) {
513 FCP_CHECK(object_state_ == ObjectState::kInitialized)
514 << "Invalid call sequence";
515 object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
516
517 absl::Time start_time = absl::Now();
518
519 // Send an EligibilityEvalCheckinRequest.
520 absl::Status request_status = SendEligibilityEvalCheckinRequest();
521 // See note about how we handle 'permanent' errors at the top of this file.
522 UpdateObjectStateIfPermanentError(
523 request_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
524 FCP_RETURN_IF_ERROR(request_status);
525
526 // Receive a CheckinRequestAck.
527 absl::Status ack_status = ReceiveCheckinRequestAck();
528 UpdateObjectStateIfPermanentError(
529 ack_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
530 FCP_RETURN_IF_ERROR(ack_status);
531
532 // Receive + handle an EligibilityEvalCheckinResponse message, and update the
533 // object state based on the received response.
534 auto response = ReceiveEligibilityEvalCheckinResponse(
535 start_time, payload_uris_received_callback);
536 UpdateObjectStateIfPermanentError(
537 response.status(),
538 ObjectState::kEligibilityEvalCheckinFailedPermanentError);
539 return response;
540 }
541
542 // This is not supported in gRPC federated protocol, we'll do nothing.
ReportEligibilityEvalError(absl::Status error_status)543 void GrpcFederatedProtocol::ReportEligibilityEvalError(
544 absl::Status error_status) {}
545
Checkin(const std::optional<TaskEligibilityInfo> & task_eligibility_info,std::function<void (const TaskAssignment &)> payload_uris_received_callback)546 absl::StatusOr<FederatedProtocol::CheckinResult> GrpcFederatedProtocol::Checkin(
547 const std::optional<TaskEligibilityInfo>& task_eligibility_info,
548 std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
549 // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
550 // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result.
551 FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
552 object_state_ == ObjectState::kEligibilityEvalEnabled)
553 << "Checkin(...) called despite failed/rejected earlier "
554 "EligibilityEvalCheckin";
555 if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
556 FCP_CHECK(task_eligibility_info.has_value())
557 << "Missing TaskEligibilityInfo despite receiving prior "
558 "EligibilityEvalCheckin payload";
559 } else {
560 FCP_CHECK(!task_eligibility_info.has_value())
561 << "Received TaskEligibilityInfo despite not receiving a prior "
562 "EligibilityEvalCheckin payload";
563 }
564
565 object_state_ = ObjectState::kCheckinFailed;
566
567 absl::Time start_time = absl::Now();
568 // Send a CheckinRequest.
569 absl::Status request_status = SendCheckinRequest(task_eligibility_info);
570 // See note about how we handle 'permanent' errors at the top of this file.
571 UpdateObjectStateIfPermanentError(request_status,
572 ObjectState::kCheckinFailedPermanentError);
573 FCP_RETURN_IF_ERROR(request_status);
574
575 // Receive + handle a CheckinResponse message, and update the object state
576 // based on the received response.
577 auto response =
578 ReceiveCheckinResponse(start_time, payload_uris_received_callback);
579 UpdateObjectStateIfPermanentError(response.status(),
580 ObjectState::kCheckinFailedPermanentError);
581 return response;
582 }
583
584 absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)585 GrpcFederatedProtocol::PerformMultipleTaskAssignments(
586 const std::vector<std::string>& task_names) {
587 return absl::UnimplementedError(
588 "PerformMultipleTaskAssignments is not supported by "
589 "GrpcFederatedProtocol.");
590 }
591
ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)592 absl::Status GrpcFederatedProtocol::ReportCompleted(
593 ComputationResults results, absl::Duration plan_duration,
594 std::optional<std::string> aggregation_session_id) {
595 FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
596 FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
597 << "Invalid call sequence";
598 object_state_ = ObjectState::kReportCalled;
599 auto response = Report(std::move(results), engine::COMPLETED, plan_duration);
600 // See note about how we handle 'permanent' errors at the top of this file.
601 UpdateObjectStateIfPermanentError(response,
602 ObjectState::kReportFailedPermanentError);
603 return response;
604 }
605
ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_Id)606 absl::Status GrpcFederatedProtocol::ReportNotCompleted(
607 engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
608 std::optional<std::string> aggregation_session_Id) {
609 FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
610 FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
611 << "Invalid call sequence";
612 object_state_ = ObjectState::kReportCalled;
613 ComputationResults results;
614 results.emplace("tensorflow_checkpoint", "");
615 auto response = Report(std::move(results), phase_outcome, plan_duration);
616 // See note about how we handle 'permanent' errors at the top of this file.
617 UpdateObjectStateIfPermanentError(response,
618 ObjectState::kReportFailedPermanentError);
619 return response;
620 }
621
622 class GrpcSecAggSendToServerImpl : public SecAggSendToServerBase {
623 public:
GrpcSecAggSendToServerImpl(GrpcBidiStreamInterface * grpc_bidi_stream,const std::function<absl::Status (ClientToServerWrapperMessage *)> & report_func)624 GrpcSecAggSendToServerImpl(
625 GrpcBidiStreamInterface* grpc_bidi_stream,
626 const std::function<absl::Status(ClientToServerWrapperMessage*)>&
627 report_func)
628 : grpc_bidi_stream_(grpc_bidi_stream), report_func_(report_func) {}
629 ~GrpcSecAggSendToServerImpl() override = default;
630
Send(ClientToServerWrapperMessage * message)631 void Send(ClientToServerWrapperMessage* message) override {
632 // The commit message (MaskedInputRequest) must be piggy-backed onto the
633 // ReportRequest message, the logic for which is encapsulated in
634 // report_func_ so that it may be held in common between both accumulation
635 // methods.
636 if (message->message_content_case() ==
637 ClientToServerWrapperMessage::MessageContentCase::
638 kMaskedInputResponse) {
639 auto status = report_func_(message);
640 if (!status.ok())
641 FCP_LOG(ERROR) << "Could not send ReportRequest: " << status;
642 return;
643 }
644 ClientStreamMessage client_stream_message;
645 client_stream_message.mutable_secure_aggregation_client_message()->Swap(
646 message);
647 auto bytes_to_upload = client_stream_message.ByteSizeLong();
648 auto status = grpc_bidi_stream_->Send(&client_stream_message);
649 if (status.ok()) {
650 last_sent_message_size_ = bytes_to_upload;
651 }
652 }
653
654 private:
655 GrpcBidiStreamInterface* grpc_bidi_stream_;
656 // SecAgg's output must be wrapped in a ReportRequest; because the report
657 // logic is mostly generic, this lambda allows it to be shared between
658 // aggregation types.
659 const std::function<absl::Status(ClientToServerWrapperMessage*)>&
660 report_func_;
661 };
662
663 class GrpcSecAggProtocolDelegate : public SecAggProtocolDelegate {
664 public:
GrpcSecAggProtocolDelegate(absl::flat_hash_map<std::string,SideChannelExecutionInfo> side_channels,GrpcBidiStreamInterface * grpc_bidi_stream)665 GrpcSecAggProtocolDelegate(
666 absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels,
667 GrpcBidiStreamInterface* grpc_bidi_stream)
668 : side_channels_(std::move(side_channels)),
669 grpc_bidi_stream_(grpc_bidi_stream) {}
670
GetModulus(const std::string & key)671 absl::StatusOr<uint64_t> GetModulus(const std::string& key) override {
672 auto execution_info = side_channels_.find(key);
673 if (execution_info == side_channels_.end())
674 return absl::InternalError(
675 absl::StrCat("Execution not found for aggregand: ", key));
676 uint64_t modulus;
677 auto secure_aggregand = execution_info->second.secure_aggregand();
678 // TODO(team): Delete output_bitwidth support once
679 // modulus is fully rolled out.
680 if (secure_aggregand.modulus() > 0) {
681 modulus = secure_aggregand.modulus();
682 } else {
683 // Note: we ignore vector.get_bitwidth() here, because (1)
684 // it is only an upper bound on the *input* bitwidth,
685 // based on the Tensorflow dtype, but (2) we have exact
686 // *output* bitwidth information from the execution_info,
687 // and that is what SecAgg needs.
688 modulus = 1ULL << secure_aggregand.output_bitwidth();
689 }
690 return modulus;
691 }
692
ReceiveServerMessage()693 absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
694 override {
695 ServerStreamMessage server_stream_message;
696 absl::Status receive_status =
697 grpc_bidi_stream_->Receive(&server_stream_message);
698 if (!receive_status.ok()) {
699 return absl::Status(receive_status.code(),
700 absl::StrCat("Error during SecAgg receive: ",
701 receive_status.message()));
702 }
703 last_received_message_size_ = server_stream_message.ByteSizeLong();
704 if (!server_stream_message.has_secure_aggregation_server_message()) {
705 return absl::InternalError(
706 absl::StrCat("Bad response to SecAgg protocol; Expected "
707 "ServerToClientWrapperMessage but got ",
708 server_stream_message.kind_case(), "."));
709 }
710 return server_stream_message.secure_aggregation_server_message();
711 }
712
Abort()713 void Abort() override { grpc_bidi_stream_->Close(); }
last_received_message_size()714 size_t last_received_message_size() override {
715 return last_received_message_size_;
716 };
717
718 private:
719 absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels_;
720 GrpcBidiStreamInterface* grpc_bidi_stream_;
721 size_t last_received_message_size_;
722 };
723
ReportInternal(std::string tf_checkpoint,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,ClientToServerWrapperMessage * secagg_commit_message)724 absl::Status GrpcFederatedProtocol::ReportInternal(
725 std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
726 absl::Duration plan_duration,
727 ClientToServerWrapperMessage* secagg_commit_message) {
728 ClientStreamMessage client_stream_message;
729 auto report_request = client_stream_message.mutable_report_request();
730 report_request->set_population_name(population_name_);
731 report_request->set_execution_phase_id(execution_phase_id_);
732 auto report = report_request->mutable_report();
733
734 // 1. Include TF checkpoint and/or SecAgg commit message.
735 report->set_update_checkpoint(std::move(tf_checkpoint));
736 if (secagg_commit_message) {
737 client_stream_message.mutable_secure_aggregation_client_message()->Swap(
738 secagg_commit_message);
739 }
740
741 // 2. Include outcome of computation.
742 report->set_status_code(phase_outcome == engine::COMPLETED
743 ? google::rpc::OK
744 : google::rpc::INTERNAL);
745
746 // 3. Include client execution statistics, if any.
747 ClientExecutionStats client_execution_stats;
748 client_execution_stats.mutable_duration()->set_seconds(
749 absl::IDivDuration(plan_duration, absl::Seconds(1), &plan_duration));
750 client_execution_stats.mutable_duration()->set_nanos(static_cast<int32_t>(
751 absl::IDivDuration(plan_duration, absl::Nanoseconds(1), &plan_duration)));
752 report->add_serialized_train_event()->PackFrom(client_execution_stats);
753
754 // 4. Send ReportRequest.
755
756 // Note that we do not use the GrpcFederatedProtocol::Send(...) helper method
757 // here, since we are already running within a call to
758 // InterruptibleRunner::Run.
759 const auto status = this->grpc_bidi_stream_->Send(&client_stream_message);
760 if (!status.ok()) {
761 return absl::Status(
762 status.code(),
763 absl::StrCat("Error sending ReportRequest: ", status.message()));
764 }
765
766 return absl::OkStatus();
767 }
768
Report(ComputationResults results,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration)769 absl::Status GrpcFederatedProtocol::Report(ComputationResults results,
770 engine::PhaseOutcome phase_outcome,
771 absl::Duration plan_duration) {
772 std::string tf_checkpoint;
773 bool has_checkpoint;
774 for (auto& [k, v] : results) {
775 if (std::holds_alternative<TFCheckpoint>(v)) {
776 tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
777 has_checkpoint = true;
778 break;
779 }
780 }
781
782 // This lambda allows for convenient reporting from within SecAgg's
783 // SendToServerInterface::Send().
784 std::function<absl::Status(ClientToServerWrapperMessage*)> report_lambda =
785 [&](ClientToServerWrapperMessage* secagg_commit_message) -> absl::Status {
786 return ReportInternal(std::move(tf_checkpoint), phase_outcome,
787 plan_duration, secagg_commit_message);
788 };
789
790 // Run the Secure Aggregation protocol, if necessary.
791 if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
792 auto secure_aggregation_protocol_execution_info =
793 side_channel_protocol_execution_info_.secure_aggregation();
794 auto expected_number_of_clients =
795 secure_aggregation_protocol_execution_info.expected_number_of_clients();
796
797 FCP_LOG(INFO) << "Reporting via Secure Aggregation";
798 if (phase_outcome != engine::COMPLETED)
799 return absl::InternalError(
800 "Aborting the SecAgg protocol (no update was produced).");
801
802 if (side_channel_protocol_options_response_.secure_aggregation()
803 .client_variant() != secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1) {
804 log_manager_->LogDiag(
805 ProdDiagCode::SECAGG_CLIENT_ERROR_UNSUPPORTED_VERSION);
806 return absl::InternalError(absl::StrCat(
807 "Unsupported SecAgg client variant: ",
808 side_channel_protocol_options_response_.secure_aggregation()
809 .client_variant()));
810 }
811
812 auto send_to_server_impl = std::make_unique<GrpcSecAggSendToServerImpl>(
813 grpc_bidi_stream_.get(), report_lambda);
814 auto secagg_event_publisher = event_publisher_->secagg_event_publisher();
815 FCP_CHECK(secagg_event_publisher)
816 << "An implementation of "
817 << "SecAggEventPublisher must be provided.";
818 auto delegate = std::make_unique<GrpcSecAggProtocolDelegate>(
819 side_channels_, grpc_bidi_stream_.get());
820 std::unique_ptr<SecAggRunner> secagg_runner =
821 secagg_runner_factory_->CreateSecAggRunner(
822 std::move(send_to_server_impl), std::move(delegate),
823 secagg_event_publisher, log_manager_, interruptible_runner_.get(),
824 expected_number_of_clients,
825 secure_aggregation_protocol_execution_info
826 .minimum_surviving_clients_for_reconstruction());
827
828 FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
829 } else {
830 // Report without secure aggregation.
831 FCP_LOG(INFO) << "Reporting via Simple Aggregation";
832 if (results.size() != 1 || !has_checkpoint) {
833 return absl::InternalError(
834 "Simple Aggregation aggregands have unexpected format.");
835 }
836 FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
837 [&report_lambda]() { return report_lambda(nullptr); },
838 [this]() {
839 // What about event_publisher_ and log_manager_?
840 this->grpc_bidi_stream_->Close();
841 }));
842 }
843
844 FCP_LOG(INFO) << "Finished reporting.";
845
846 // Receive ReportResponse.
847 ServerStreamMessage server_stream_message;
848 absl::Status receive_status = Receive(&server_stream_message);
849 if (receive_status.code() == absl::StatusCode::kAborted) {
850 FCP_LOG(INFO) << "Server responded ABORTED.";
851 } else if (receive_status.code() == absl::StatusCode::kCancelled) {
852 FCP_LOG(INFO) << "Upload was cancelled by the client.";
853 }
854 if (!receive_status.ok()) {
855 return absl::Status(
856 receive_status.code(),
857 absl::StrCat("Error after ReportRequest: ", receive_status.message()));
858 }
859 if (!server_stream_message.has_report_response()) {
860 return absl::UnimplementedError(absl::StrCat(
861 "Bad response to ReportRequest; Expected REPORT_RESPONSE but got ",
862 server_stream_message.kind_case(), "."));
863 }
864 return absl::OkStatus();
865 }
866
GetLatestRetryWindow()867 RetryWindow GrpcFederatedProtocol::GetLatestRetryWindow() {
868 // We explicitly enumerate all possible states here rather than using
869 // "default", to ensure that when new states are added later on, the author
870 // is forced to update this method and consider which is the correct
871 // RetryWindow to return.
872 switch (object_state_) {
873 case ObjectState::kCheckinAccepted:
874 case ObjectState::kReportCalled:
875 // If a client makes it past the 'checkin acceptance' stage, we use the
876 // 'accepted' RetryWindow unconditionally (unless a permanent error is
877 // encountered). This includes cases where the checkin is accepted, but
878 // the report request results in a (transient) error.
879 FCP_CHECK(checkin_request_ack_info_.has_value());
880 return GenerateRetryWindowFromRetryTimeAndToken(
881 checkin_request_ack_info_->retry_info_if_accepted);
882 case ObjectState::kEligibilityEvalCheckinRejected:
883 case ObjectState::kEligibilityEvalDisabled:
884 case ObjectState::kEligibilityEvalEnabled:
885 case ObjectState::kCheckinRejected:
886 FCP_CHECK(checkin_request_ack_info_.has_value());
887 return GenerateRetryWindowFromRetryTimeAndToken(
888 checkin_request_ack_info_->retry_info_if_rejected);
889 case ObjectState::kInitialized:
890 case ObjectState::kEligibilityEvalCheckinFailed:
891 case ObjectState::kCheckinFailed:
892 // If the flag is true, then we use the previously chosen absolute retry
893 // time instead (if available).
894 if (checkin_request_ack_info_.has_value()) {
895 // If we already received a server-provided retry window, then use it.
896 return GenerateRetryWindowFromRetryTimeAndToken(
897 checkin_request_ack_info_->retry_info_if_rejected);
898 }
899 // Otherwise, we generate a retry window using the flag-provided transient
900 // error retry period.
901 return GenerateRetryWindowFromTargetDelay(
902 absl::Seconds(
903 flags_->federated_training_transient_errors_retry_delay_secs()),
904 // NOLINTBEGIN(whitespace/line_length)
905 flags_
906 ->federated_training_transient_errors_retry_delay_jitter_percent(),
907 // NOLINTEND
908 bit_gen_);
909 case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
910 case ObjectState::kCheckinFailedPermanentError:
911 case ObjectState::kReportFailedPermanentError:
912 // If we encountered a permanent error during the eligibility eval or
913 // regular checkins, then we use the Flags-configured 'permanent error'
914 // retry period. Note that we do so regardless of whether the server had,
915 // by the time the permanent error was received, already returned a
916 // CheckinRequestAck containing a set of retry windows. See note on error
917 // handling at the top of this file.
918 return GenerateRetryWindowFromTargetDelay(
919 absl::Seconds(
920 flags_->federated_training_permanent_errors_retry_delay_secs()),
921 // NOLINTBEGIN(whitespace/line_length)
922 flags_
923 ->federated_training_permanent_errors_retry_delay_jitter_percent(),
924 // NOLINTEND
925 bit_gen_);
926 case ObjectState::kMultipleTaskAssignmentsAccepted:
927 case ObjectState::kMultipleTaskAssignmentsFailed:
928 case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
929 case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
930 case ObjectState::kReportMultipleTaskPartialError:
931 FCP_LOG(FATAL) << "Multi-task assignments is not supported by gRPC.";
932 RetryWindow retry_window;
933 return retry_window;
934 }
935 }
936
937 // Converts the given RetryTimeAndToken to a zero-width RetryWindow (where
938 // delay_min and delay_max are set to the same value), by converting the target
939 // retry time to a delay relative to the current timestamp.
GenerateRetryWindowFromRetryTimeAndToken(const GrpcFederatedProtocol::RetryTimeAndToken & retry_info)940 RetryWindow GrpcFederatedProtocol::GenerateRetryWindowFromRetryTimeAndToken(
941 const GrpcFederatedProtocol::RetryTimeAndToken& retry_info) {
942 // Generate a RetryWindow with delay_min and delay_max both set to the same
943 // value.
944 RetryWindow retry_window =
945 GenerateRetryWindowFromRetryTime(retry_info.retry_time);
946 retry_window.set_retry_token(retry_info.retry_token);
947 return retry_window;
948 }
949
UpdateObjectStateIfPermanentError(absl::Status status,GrpcFederatedProtocol::ObjectState permanent_error_object_state)950 void GrpcFederatedProtocol::UpdateObjectStateIfPermanentError(
951 absl::Status status,
952 GrpcFederatedProtocol::ObjectState permanent_error_object_state) {
953 if (federated_training_permanent_error_codes_.contains(
954 static_cast<int32_t>(status.code()))) {
955 object_state_ = permanent_error_object_state;
956 }
957 }
958
959 absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
FetchTaskResources(GrpcFederatedProtocol::TaskResources task_resources)960 GrpcFederatedProtocol::FetchTaskResources(
961 GrpcFederatedProtocol::TaskResources task_resources) {
962 FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
963 ConvertResourceToUriOrInlineData(task_resources.plan));
964 FCP_ASSIGN_OR_RETURN(
965 UriOrInlineData checkpoint_uri_or_data,
966 ConvertResourceToUriOrInlineData(task_resources.checkpoint));
967
968 // Log a diag code if either resource is about to be downloaded via HTTP.
969 if (!plan_uri_or_data.uri().uri.empty() ||
970 !checkpoint_uri_or_data.uri().uri.empty()) {
971 log_manager_->LogDiag(
972 ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP);
973 }
974
975 // Fetch the plan and init checkpoint resources if they need to be fetched
976 // (using the inline data instead if available).
977 absl::StatusOr<
978 std::vector<absl::StatusOr<::fcp::client::http::InMemoryHttpResponse>>>
979 resource_responses;
980 {
981 auto started_stopwatch = network_stopwatch_->Start();
982 resource_responses = ::fcp::client::http::FetchResourcesInMemory(
983 *http_client_, *interruptible_runner_,
984 {plan_uri_or_data, checkpoint_uri_or_data}, &http_bytes_downloaded_,
985 &http_bytes_uploaded_, resource_cache_);
986 }
987 if (!resource_responses.ok()) {
988 log_manager_->LogDiag(
989 ProdDiagCode::
990 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
991 return resource_responses.status();
992 }
993 auto& plan_data_response = (*resource_responses)[0];
994 auto& checkpoint_data_response = (*resource_responses)[1];
995
996 if (!plan_data_response.ok() || !checkpoint_data_response.ok()) {
997 log_manager_->LogDiag(
998 ProdDiagCode::
999 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
1000 }
1001 // Note: we forward any error during the fetching of the plan/checkpoint
1002 // resources resources to the caller, which means that these error codes
1003 // will be checked against the set of 'permanent' error codes, just like the
1004 // errors in response to the protocol request are.
1005 if (!plan_data_response.ok()) {
1006 return absl::Status(plan_data_response.status().code(),
1007 absl::StrCat("plan fetch failed: ",
1008 plan_data_response.status().ToString()));
1009 }
1010 if (!checkpoint_data_response.ok()) {
1011 return absl::Status(
1012 checkpoint_data_response.status().code(),
1013 absl::StrCat("checkpoint fetch failed: ",
1014 checkpoint_data_response.status().ToString()));
1015 }
1016 if (!plan_uri_or_data.uri().uri.empty() ||
1017 !checkpoint_uri_or_data.uri().uri.empty()) {
1018 // We only want to log this diag code when we actually did fetch something
1019 // via HTTP.
1020 log_manager_->LogDiag(
1021 ProdDiagCode::
1022 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED);
1023 }
1024
1025 return PlanAndCheckpointPayloads{plan_data_response->body,
1026 checkpoint_data_response->body};
1027 }
1028
1029 // Convert a Resource proto into a UriOrInlineData object. Returns an
1030 // `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
1031 // an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
1032 // field set.
1033 absl::StatusOr<UriOrInlineData>
ConvertResourceToUriOrInlineData(const GrpcFederatedProtocol::TaskResource & resource)1034 GrpcFederatedProtocol::ConvertResourceToUriOrInlineData(
1035 const GrpcFederatedProtocol::TaskResource& resource) {
1036 // We need to support 3 states:
1037 // - Inline data is available.
1038 // - No inline data nor is there a URI. This should be treated as there being
1039 // an 'empty' inline data.
1040 // - No inline data is available but a URI is available.
1041 if (!resource.has_uri) {
1042 // If the URI field wasn't set, then we'll just use the inline data field
1043 // (which will either be set or be empty).
1044 //
1045 // Note: this copies the data into the new absl::Cord. However, this Cord is
1046 // then passed around all the way to fl_runner.cc without copying its data,
1047 // so this is ultimately approx. as efficient as the non-HTTP resource code
1048 // path where we also make a copy of the protobuf string into a new string
1049 // which is then returned.
1050 return UriOrInlineData::CreateInlineData(
1051 absl::Cord(resource.data),
1052 UriOrInlineData::InlineData::CompressionFormat::kUncompressed);
1053 }
1054 if (resource.uri.empty()) {
1055 return absl::InvalidArgumentError(
1056 "Resource uri must be non-empty when set");
1057 }
1058 return UriOrInlineData::CreateUri(resource.uri, resource.client_cache_id,
1059 resource.max_age);
1060 }
1061
GetNetworkStats()1062 NetworkStats GrpcFederatedProtocol::GetNetworkStats() {
1063 // Note: the `HttpClient` bandwidth stats are similar to the gRPC protocol's
1064 // "chunking layer" stats, in that they reflect as closely as possible the
1065 // amount of data sent on the wire.
1066 return {.bytes_downloaded = grpc_bidi_stream_->ChunkingLayerBytesReceived() +
1067 http_bytes_downloaded_,
1068 .bytes_uploaded = grpc_bidi_stream_->ChunkingLayerBytesSent() +
1069 http_bytes_uploaded_,
1070 .network_duration = network_stopwatch_->GetTotalDuration()};
1071 }
1072
1073 } // namespace client
1074 } // namespace fcp
1075