xref: /aosp_15_r20/external/federated-compute/fcp/client/fl_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/fl_runner.h"
17 
18 #include <fcntl.h>
19 
20 #include <fstream>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 #include <optional>
25 #include <string>
26 #include <utility>
27 #include <variant>
28 #include <vector>
29 
30 #include "absl/status/status.h"
31 #include "absl/status/statusor.h"
32 #include "absl/strings/cord.h"
33 #include "absl/time/time.h"
34 #include "fcp/base/clock.h"
35 #include "fcp/base/monitoring.h"
36 #include "fcp/base/platform.h"
37 // #include "fcp/client/cache/file_backed_resource_cache.h"
38 #include "fcp/client/cache/resource_cache.h"
39 #include "fcp/client/engine/common.h"
40 #include "fcp/client/engine/engine.pb.h"
41 #include "fcp/client/engine/example_iterator_factory.h"
42 #include "fcp/client/engine/example_query_plan_engine.h"
43 #include "fcp/client/engine/plan_engine_helpers.h"
44 #include "fcp/client/opstats/opstats_utils.h"
45 #include "fcp/client/parsing_utils.h"
46 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
47 #include "fcp/client/engine/simple_plan_engine.h"
48 #endif
49 #include "fcp/client/engine/tflite_plan_engine.h"
50 #include "fcp/client/event_publisher.h"
51 #include "fcp/client/federated_protocol.h"
52 #include "fcp/client/federated_protocol_util.h"
53 #include "fcp/client/files.h"
54 #include "fcp/client/fl_runner.pb.h"
55 #include "fcp/client/flags.h"
56 #include "fcp/client/http/http_federated_protocol.h"
57 #ifdef FCP_CLIENT_SUPPORT_GRPC
58 #include "fcp/client/grpc_federated_protocol.h"
59 #endif
60 #include "fcp/client/interruptible_runner.h"
61 #include "fcp/client/log_manager.h"
62 #include "fcp/client/opstats/opstats_example_store.h"
63 #include "fcp/client/phase_logger_impl.h"
64 #include "fcp/client/secagg_runner.h"
65 #include "fcp/client/selector_context.pb.h"
66 #include "fcp/client/simple_task_environment.h"
67 #include "fcp/protos/federated_api.pb.h"
68 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
69 #include "fcp/protos/opstats.pb.h"
70 #include "fcp/protos/plan.pb.h"
71 #include "openssl/digest.h"
72 #include "openssl/evp.h"
73 #include "tensorflow/core/framework/tensor.h"
74 #include "tensorflow/core/framework/tensor.pb.h"
75 #include "tensorflow/core/framework/tensor_shape.pb.h"
76 #include "tensorflow/core/protobuf/struct.pb.h"
77 namespace fcp {
78 namespace client {
79 using ::fcp::client::opstats::OpStatsLogger;
80 using ::google::internal::federated::plan::AggregationConfig;
81 using ::google::internal::federated::plan::ClientOnlyPlan;
82 using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
83 using ::google::internal::federated::plan::FederatedComputeIORouter;
84 using ::google::internal::federated::plan::TensorflowSpec;
85 using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
86 using ::google::internal::federatedml::v2::RetryWindow;
87 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
88 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
89 namespace {
90 template <typename T>
AddValuesToQuantized(QuantizedTensor * quantized,const tensorflow::Tensor & tensor)91 void AddValuesToQuantized(QuantizedTensor* quantized,
92                           const tensorflow::Tensor& tensor) {
93   auto flat_tensor = tensor.flat<T>();
94   quantized->values.reserve(quantized->values.size() + flat_tensor.size());
95   for (int i = 0; i < flat_tensor.size(); i++) {
96     quantized->values.push_back(flat_tensor(i));
97   }
98 }
ComputeSHA256FromStringOrCord(std::variant<std::string,absl::Cord> data)99 std::string ComputeSHA256FromStringOrCord(
100     std::variant<std::string, absl::Cord> data) {
101   std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
102                                                            EVP_MD_CTX_destroy);
103   FCP_CHECK(EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr));
104   std::string plan_str;
105   if (std::holds_alternative<std::string>(data)) {
106     plan_str = std::get<std::string>(data);
107   } else {
108     plan_str = std::string(std::get<absl::Cord>(data));
109   }
110   FCP_CHECK(EVP_DigestUpdate(mdctx.get(), plan_str.c_str(), sizeof(int)));
111   const int hash_len = 32;  // 32 bytes for SHA-256.
112   uint8_t computation_id_bytes[hash_len];
113   FCP_CHECK(EVP_DigestFinal_ex(mdctx.get(), computation_id_bytes, nullptr));
114   return std::string(reinterpret_cast<char const*>(computation_id_bytes),
115                      hash_len);
116 }
117 struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anon340cd2590111::PlanResultAndCheckpointFile118   explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
119       : plan_result(std::move(plan_result)) {}
120   engine::PlanResult plan_result;
121   std::string checkpoint_file;
122   PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
123   PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
124       default;
125   // Disallow copy and assign.
126   PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
127   PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
128       delete;
129 };
130 // Creates computation results. The method checks for SecAgg tensors only if
131 // `tensorflow_spec != nullptr`.
CreateComputationResults(const TensorflowSpec * tensorflow_spec,const PlanResultAndCheckpointFile & plan_result_and_checkpoint_file)132 absl::StatusOr<ComputationResults> CreateComputationResults(
133     const TensorflowSpec* tensorflow_spec,
134     const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file) {
135   const auto& [plan_result, checkpoint_file] = plan_result_and_checkpoint_file;
136   if (plan_result.outcome != engine::PlanOutcome::kSuccess) {
137     return absl::InvalidArgumentError("Computation failed.");
138   }
139   ComputationResults computation_results;
140   if (tensorflow_spec != nullptr) {
141     for (int i = 0; i < plan_result.output_names.size(); i++) {
142       QuantizedTensor quantized;
143       const auto& output_tensor = plan_result.output_tensors[i];
144       switch (output_tensor.dtype()) {
145         case tensorflow::DT_INT8:
146           AddValuesToQuantized<int8_t>(&quantized, output_tensor);
147           quantized.bitwidth = 7;
148           break;
149         case tensorflow::DT_UINT8:
150           AddValuesToQuantized<uint8_t>(&quantized, output_tensor);
151           quantized.bitwidth = 8;
152           break;
153         case tensorflow::DT_INT16:
154           AddValuesToQuantized<int16_t>(&quantized, output_tensor);
155           quantized.bitwidth = 15;
156           break;
157         case tensorflow::DT_UINT16:
158           AddValuesToQuantized<uint16_t>(&quantized, output_tensor);
159           quantized.bitwidth = 16;
160           break;
161         case tensorflow::DT_INT32:
162           AddValuesToQuantized<int32_t>(&quantized, output_tensor);
163           quantized.bitwidth = 31;
164           break;
165         case tensorflow::DT_INT64:
166           AddValuesToQuantized<tensorflow::int64>(&quantized, output_tensor);
167           quantized.bitwidth = 62;
168           break;
169         default:
170           return absl::InvalidArgumentError(
171               absl::StrCat("Tensor of type",
172                            tensorflow::DataType_Name(output_tensor.dtype()),
173                            "could not be converted to quantized value"));
174       }
175       computation_results[plan_result.output_names[i]] = std::move(quantized);
176     }
177     // Add dimensions to QuantizedTensors.
178     for (const tensorflow::TensorSpecProto& tensor_spec :
179          tensorflow_spec->output_tensor_specs()) {
180       if (computation_results.find(tensor_spec.name()) !=
181           computation_results.end()) {
182         for (const tensorflow::TensorShapeProto_Dim& dim :
183              tensor_spec.shape().dim()) {
184           std::get<QuantizedTensor>(computation_results[tensor_spec.name()])
185               .dimensions.push_back(dim.size());
186         }
187       }
188     }
189   }
190   // Name of the TF checkpoint inside the aggregand map in the Checkpoint
191   // protobuf. This field name is ignored by the server.
192   if (!checkpoint_file.empty()) {
193     FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
194                          fcp::ReadFileToString(checkpoint_file));
195     computation_results[std::string(kTensorflowCheckpointAggregand)] =
196         std::move(tf_checkpoint);
197   }
198   return computation_results;
199 }
200 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
201 std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForEligibilityEvalPlan(const FederatedComputeEligibilityIORouter & io_router,const std::string & checkpoint_input_filename)202 ConstructInputsForEligibilityEvalPlan(
203     const FederatedComputeEligibilityIORouter& io_router,
204     const std::string& checkpoint_input_filename) {
205   auto inputs = std::make_unique<
206       std::vector<std::pair<std::string, tensorflow::Tensor>>>();
207   if (!io_router.input_filepath_tensor_name().empty()) {
208     tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {});
209     input_filepath.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
210     inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath});
211   }
212   return inputs;
213 }
214 #endif
ConstructTfLiteInputsForEligibilityEvalPlan(const FederatedComputeEligibilityIORouter & io_router,const std::string & checkpoint_input_filename)215 std::unique_ptr<TfLiteInputs> ConstructTfLiteInputsForEligibilityEvalPlan(
216     const FederatedComputeEligibilityIORouter& io_router,
217     const std::string& checkpoint_input_filename) {
218   auto inputs = std::make_unique<TfLiteInputs>();
219   if (!io_router.input_filepath_tensor_name().empty()) {
220     (*inputs)[io_router.input_filepath_tensor_name()] =
221         checkpoint_input_filename;
222   }
223   return inputs;
224 }
225 // Returns the cumulative network stats (those incurred up until this point in
226 // time).
227 //
228 // The `FederatedSelectManager` object may be null, if it is know that there
229 // has been no network usage from it yet.
GetCumulativeNetworkStats(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager)230 NetworkStats GetCumulativeNetworkStats(
231     FederatedProtocol* federated_protocol,
232     FederatedSelectManager* fedselect_manager) {
233   NetworkStats result = federated_protocol->GetNetworkStats();
234   if (fedselect_manager != nullptr) {
235     result = result + fedselect_manager->GetNetworkStats();
236   }
237   return result;
238 }
239 // Returns the newly incurred network stats since the previous snapshot of
240 // stats (the `reference_point` argument).
GetNetworkStatsSince(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const NetworkStats & reference_point)241 NetworkStats GetNetworkStatsSince(FederatedProtocol* federated_protocol,
242                                   FederatedSelectManager* fedselect_manager,
243                                   const NetworkStats& reference_point) {
244   return GetCumulativeNetworkStats(federated_protocol, fedselect_manager) -
245          reference_point;
246 }
247 // Updates the fields of `FLRunnerResult` that should always be updated after
248 // each interaction with the `FederatedProtocol` or `FederatedSelectManager`
249 // objects.
250 //
251 // The `FederatedSelectManager` object may be null, if it is know that there
252 // has been no network usage from it yet.
UpdateRetryWindowAndNetworkStats(FederatedProtocol & federated_protocol,FederatedSelectManager * fedselect_manager,PhaseLogger & phase_logger,FLRunnerResult & fl_runner_result)253 void UpdateRetryWindowAndNetworkStats(FederatedProtocol& federated_protocol,
254                                       FederatedSelectManager* fedselect_manager,
255                                       PhaseLogger& phase_logger,
256                                       FLRunnerResult& fl_runner_result) {
257   // Update the result's retry window to the most recent one.
258   auto retry_window = federated_protocol.GetLatestRetryWindow();
259   RetryInfo retry_info;
260   *retry_info.mutable_retry_token() = retry_window.retry_token();
261   *retry_info.mutable_minimum_delay() = retry_window.delay_min();
262   *fl_runner_result.mutable_retry_info() = retry_info;
263   phase_logger.UpdateRetryWindowAndNetworkStats(
264       retry_window,
265       GetCumulativeNetworkStats(&federated_protocol, fedselect_manager));
266 }
267 // Creates an ExampleIteratorFactory that routes queries to the
268 // SimpleTaskEnvironment::CreateExampleIterator() method.
269 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)270 CreateSimpleTaskEnvironmentIteratorFactory(
271     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
272   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
273       /*can_handle_func=*/
274       [](const google::internal::federated::plan::ExampleSelector&) {
275         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
276         // be the catch-all factory that is able to handle all queries
277         // that no other ExampleIteratorFactory is able to handle.
278         return true;
279       },
280       /*create_iterator_func=*/
281       [task_env, selector_context](
282           const google::internal::federated::plan::ExampleSelector&
283               example_selector) {
284         return task_env->CreateExampleIterator(example_selector,
285                                                selector_context);
286       },
287       /*should_collect_stats=*/true);
288 }
RunEligibilityEvalPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)289 engine::PlanResult RunEligibilityEvalPlanWithTensorflowSpec(
290     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
291     std::function<bool()> should_abort, LogManager* log_manager,
292     OpStatsLogger* opstats_logger, const Flags* flags,
293     const ClientOnlyPlan& client_plan,
294     const std::string& checkpoint_input_filename,
295     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
296     const absl::Time run_plan_start_time, const absl::Time reference_time) {
297   // Check that this is a TensorflowSpec-based plan for federated eligibility
298   // computation.
299   if (!client_plan.phase().has_tensorflow_spec() ||
300       !client_plan.phase().has_federated_compute_eligibility()) {
301     return engine::PlanResult(
302         engine::PlanOutcome::kInvalidArgument,
303         absl::InvalidArgumentError("Invalid eligibility eval plan"));
304   }
305   const FederatedComputeEligibilityIORouter& io_router =
306       client_plan.phase().federated_compute_eligibility();
307   std::vector<std::string> output_names = {
308       io_router.task_eligibility_info_tensor_name()};
309   if (!client_plan.tflite_graph().empty()) {
310     log_manager->LogDiag(
311         ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
312   }
313   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
314     std::unique_ptr<TfLiteInputs> tflite_inputs =
315         ConstructTfLiteInputsForEligibilityEvalPlan(io_router,
316                                                     checkpoint_input_filename);
317     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
318                                          should_abort, log_manager,
319                                          opstats_logger, flags, &timing_config);
320     return plan_engine.RunPlan(client_plan.phase().tensorflow_spec(),
321                                client_plan.tflite_graph(),
322                                std::move(tflite_inputs), output_names);
323   }
324 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
325   // Construct input tensors and output tensor names based on the values in
326   // the FederatedComputeEligibilityIORouter message.
327   auto inputs = ConstructInputsForEligibilityEvalPlan(
328       io_router, checkpoint_input_filename);
329   // Run plan and get a set of output tensors back.
330   engine::SimplePlanEngine plan_engine(
331       example_iterator_factories, should_abort, log_manager, opstats_logger,
332       &timing_config, flags->support_constant_tf_inputs());
333   return plan_engine.RunPlan(
334       client_plan.phase().tensorflow_spec(), client_plan.graph(),
335       client_plan.tensorflow_config_proto(), std::move(inputs), output_names);
336 #else
337   return engine::PlanResult(
338       engine::PlanOutcome::kTensorflowError,
339       absl::InternalError("No eligibility eval plan engine enabled"));
340 #endif
341 }
342 // Validates the output tensors that resulted from executing the plan, and
343 // then parses the output into a TaskEligibilityInfo proto. Returns an error
344 // if validation or parsing failed.
ParseEligibilityEvalPlanOutput(const std::vector<tensorflow::Tensor> & output_tensors)345 absl::StatusOr<TaskEligibilityInfo> ParseEligibilityEvalPlanOutput(
346     const std::vector<tensorflow::Tensor>& output_tensors) {
347   auto output_size = output_tensors.size();
348   if (output_size != 1) {
349     return absl::InvalidArgumentError(
350         absl::StrCat("Unexpected number of output tensors: ", output_size));
351   }
352   auto output_elements = output_tensors[0].NumElements();
353   if (output_elements != 1) {
354     return absl::InvalidArgumentError(absl::StrCat(
355         "Unexpected number of output tensor elements: ", output_elements));
356   }
357   tensorflow::DataType output_type = output_tensors[0].dtype();
358   if (output_type != tensorflow::DT_STRING) {
359     return absl::InvalidArgumentError(
360         absl::StrCat("Unexpected output tensor type: ", output_type));
361   }
362   // Extract the serialized TaskEligibilityInfo proto from the tensor and
363   // parse it.
364   // First, convert the output Tensor into a Scalar (= a TensorMap with 1
365   // element), then use its operator() to access the actual data.
366   const tensorflow::tstring& serialized_output =
367       output_tensors[0].scalar<const tensorflow::tstring>()();
368   TaskEligibilityInfo parsed_output;
369   if (!parsed_output.ParseFromString(serialized_output)) {
370     return absl::InvalidArgumentError("Could not parse output proto");
371   }
372   return parsed_output;
373 }
374 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
375 std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)376 ConstructInputsForTensorflowSpecPlan(
377     const FederatedComputeIORouter& io_router,
378     const std::string& checkpoint_input_filename,
379     const std::string& checkpoint_output_filename) {
380   auto inputs = std::make_unique<
381       std::vector<std::pair<std::string, tensorflow::Tensor>>>();
382   if (!io_router.input_filepath_tensor_name().empty()) {
383     tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {});
384     input_filepath.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
385     inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath});
386   }
387   if (!io_router.output_filepath_tensor_name().empty()) {
388     tensorflow::Tensor output_filepath(tensorflow::DT_STRING, {});
389     output_filepath.scalar<tensorflow::tstring>()() =
390         checkpoint_output_filename;
391     inputs->push_back(
392         {io_router.output_filepath_tensor_name(), output_filepath});
393   }
394   return inputs;
395 }
396 #endif
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)397 std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
398     const FederatedComputeIORouter& io_router,
399     const std::string& checkpoint_input_filename,
400     const std::string& checkpoint_output_filename) {
401   auto inputs = std::make_unique<TfLiteInputs>();
402   if (!io_router.input_filepath_tensor_name().empty()) {
403     (*inputs)[io_router.input_filepath_tensor_name()] =
404         checkpoint_input_filename;
405   }
406   if (!io_router.output_filepath_tensor_name().empty()) {
407     (*inputs)[io_router.output_filepath_tensor_name()] =
408         checkpoint_output_filename;
409   }
410   return inputs;
411 }
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)412 absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
413     const TensorflowSpec& tensorflow_spec,
414     const FederatedComputeIORouter& io_router) {
415   std::vector<std::string> output_names;
416   // The order of output tensor names should match the order in
417   // TensorflowSpec.
418   for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
419     std::string tensor_name = output_tensor_spec.name();
420     if (!io_router.aggregations().contains(tensor_name) ||
421         !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
422       return absl::InvalidArgumentError(
423           "Output tensor is missing in AggregationConfig, or has unsupported "
424           "aggregation type.");
425     }
426     output_names.push_back(tensor_name);
427   }
428   return output_names;
429 }
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)430 PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
431     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
432     std::function<bool()> should_abort, LogManager* log_manager,
433     OpStatsLogger* opstats_logger, const Flags* flags,
434     const ClientOnlyPlan& client_plan,
435     const std::string& checkpoint_input_filename,
436     const std::string& checkpoint_output_filename,
437     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
438   if (!client_plan.phase().has_tensorflow_spec()) {
439     return PlanResultAndCheckpointFile(engine::PlanResult(
440         engine::PlanOutcome::kInvalidArgument,
441         absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
442   }
443   if (!client_plan.phase().has_federated_compute()) {
444     return PlanResultAndCheckpointFile(engine::PlanResult(
445         engine::PlanOutcome::kInvalidArgument,
446         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
447   }
448   // Get the output tensor names.
449   absl::StatusOr<std::vector<std::string>> output_names;
450   output_names = ConstructOutputsWithDeterministicOrder(
451       client_plan.phase().tensorflow_spec(),
452       client_plan.phase().federated_compute());
453   if (!output_names.ok()) {
454     return PlanResultAndCheckpointFile(engine::PlanResult(
455         engine::PlanOutcome::kInvalidArgument, output_names.status()));
456   }
457   // Run plan and get a set of output tensors back.
458   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
459     std::unique_ptr<TfLiteInputs> tflite_inputs =
460         ConstructTFLiteInputsForTensorflowSpecPlan(
461             client_plan.phase().federated_compute(), checkpoint_input_filename,
462             checkpoint_output_filename);
463     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
464                                          should_abort, log_manager,
465                                          opstats_logger, flags, &timing_config);
466     engine::PlanResult plan_result = plan_engine.RunPlan(
467         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
468         std::move(tflite_inputs), *output_names);
469     PlanResultAndCheckpointFile result(std::move(plan_result));
470     result.checkpoint_file = checkpoint_output_filename;
471     return result;
472   }
473 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
474   // Construct input tensors based on the values in the
475   // FederatedComputeIORouter message and create a temporary file for the
476   // output checkpoint if needed.
477   auto inputs = ConstructInputsForTensorflowSpecPlan(
478       client_plan.phase().federated_compute(), checkpoint_input_filename,
479       checkpoint_output_filename);
480   engine::SimplePlanEngine plan_engine(
481       example_iterator_factories, should_abort, log_manager, opstats_logger,
482       &timing_config, flags->support_constant_tf_inputs());
483   engine::PlanResult plan_result = plan_engine.RunPlan(
484       client_plan.phase().tensorflow_spec(), client_plan.graph(),
485       client_plan.tensorflow_config_proto(), std::move(inputs), *output_names);
486   PlanResultAndCheckpointFile result(std::move(plan_result));
487   result.checkpoint_file = checkpoint_output_filename;
488   return result;
489 #else
490   return PlanResultAndCheckpointFile(
491       engine::PlanResult(engine::PlanOutcome::kTensorflowError,
492                          absl::InternalError("No plan engine enabled")));
493 #endif
494 }
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)495 PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
496     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
497     OpStatsLogger* opstats_logger, const Flags* flags,
498     const ClientOnlyPlan& client_plan,
499     const std::string& checkpoint_output_filename) {
500   if (!client_plan.phase().has_example_query_spec()) {
501     return PlanResultAndCheckpointFile(engine::PlanResult(
502         engine::PlanOutcome::kInvalidArgument,
503         absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
504   }
505   if (!flags->enable_example_query_plan_engine()) {
506     // Example query plan received while the flag is off.
507     return PlanResultAndCheckpointFile(engine::PlanResult(
508         engine::PlanOutcome::kInvalidArgument,
509         absl::InvalidArgumentError(
510             "Example query plan received while the flag is off")));
511   }
512   if (!client_plan.phase().has_federated_example_query()) {
513     return PlanResultAndCheckpointFile(engine::PlanResult(
514         engine::PlanOutcome::kInvalidArgument,
515         absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
516   }
517   for (const auto& example_query :
518        client_plan.phase().example_query_spec().example_queries()) {
519     for (auto const& [vector_name, spec] :
520          example_query.output_vector_specs()) {
521       const auto& aggregations =
522           client_plan.phase().federated_example_query().aggregations();
523       if ((aggregations.find(vector_name) == aggregations.end()) ||
524           !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
525         return PlanResultAndCheckpointFile(engine::PlanResult(
526             engine::PlanOutcome::kInvalidArgument,
527             absl::InvalidArgumentError("Output vector is missing in "
528                                        "AggregationConfig, or has unsupported "
529                                        "aggregation type.")));
530       }
531     }
532   }
533   engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
534                                              opstats_logger);
535   engine::PlanResult plan_result = plan_engine.RunPlan(
536       client_plan.phase().example_query_spec(), checkpoint_output_filename);
537   PlanResultAndCheckpointFile result(std::move(plan_result));
538   result.checkpoint_file = checkpoint_output_filename;
539   return result;
540 }
LogEligibilityEvalComputationOutcome(PhaseLogger & phase_logger,engine::PlanResult plan_result,const absl::Status & eligibility_info_parsing_status,absl::Time run_plan_start_time,absl::Time reference_time)541 void LogEligibilityEvalComputationOutcome(
542     PhaseLogger& phase_logger, engine::PlanResult plan_result,
543     const absl::Status& eligibility_info_parsing_status,
544     absl::Time run_plan_start_time, absl::Time reference_time) {
545   switch (plan_result.outcome) {
546     case engine::PlanOutcome::kSuccess: {
547       if (eligibility_info_parsing_status.ok()) {
548         phase_logger.LogEligibilityEvalComputationCompleted(
549             plan_result.example_stats, run_plan_start_time, reference_time);
550       } else {
551         phase_logger.LogEligibilityEvalComputationTensorflowError(
552             eligibility_info_parsing_status, plan_result.example_stats,
553             run_plan_start_time, reference_time);
554         FCP_LOG(ERROR) << eligibility_info_parsing_status.message();
555       }
556       break;
557     }
558     case engine::PlanOutcome::kInterrupted:
559       phase_logger.LogEligibilityEvalComputationInterrupted(
560           plan_result.original_status, plan_result.example_stats,
561           run_plan_start_time, reference_time);
562       break;
563     case engine::PlanOutcome::kInvalidArgument:
564       phase_logger.LogEligibilityEvalComputationInvalidArgument(
565           plan_result.original_status, plan_result.example_stats,
566           run_plan_start_time);
567       break;
568     case engine::PlanOutcome::kTensorflowError:
569       phase_logger.LogEligibilityEvalComputationTensorflowError(
570           plan_result.original_status, plan_result.example_stats,
571           run_plan_start_time, reference_time);
572       break;
573     case engine::PlanOutcome::kExampleIteratorError:
574       phase_logger.LogEligibilityEvalComputationExampleIteratorError(
575           plan_result.original_status, plan_result.example_stats,
576           run_plan_start_time);
577       break;
578   }
579 }
LogComputationOutcome(const engine::PlanResult & plan_result,absl::Status computation_results_parsing_status,PhaseLogger & phase_logger,const NetworkStats & network_stats,absl::Time run_plan_start_time,absl::Time reference_time)580 void LogComputationOutcome(const engine::PlanResult& plan_result,
581                            absl::Status computation_results_parsing_status,
582                            PhaseLogger& phase_logger,
583                            const NetworkStats& network_stats,
584                            absl::Time run_plan_start_time,
585                            absl::Time reference_time) {
586   switch (plan_result.outcome) {
587     case engine::PlanOutcome::kSuccess: {
588       if (computation_results_parsing_status.ok()) {
589         phase_logger.LogComputationCompleted(plan_result.example_stats,
590                                              network_stats, run_plan_start_time,
591                                              reference_time);
592       } else {
593         phase_logger.LogComputationTensorflowError(
594             computation_results_parsing_status, plan_result.example_stats,
595             network_stats, run_plan_start_time, reference_time);
596       }
597       break;
598     }
599     case engine::PlanOutcome::kInterrupted:
600       phase_logger.LogComputationInterrupted(
601           plan_result.original_status, plan_result.example_stats, network_stats,
602           run_plan_start_time, reference_time);
603       break;
604     case engine::PlanOutcome::kInvalidArgument:
605       phase_logger.LogComputationInvalidArgument(
606           plan_result.original_status, plan_result.example_stats, network_stats,
607           run_plan_start_time);
608       break;
609     case engine::PlanOutcome::kTensorflowError:
610       phase_logger.LogComputationTensorflowError(
611           plan_result.original_status, plan_result.example_stats, network_stats,
612           run_plan_start_time, reference_time);
613       break;
614     case engine::PlanOutcome::kExampleIteratorError:
615       phase_logger.LogComputationExampleIteratorError(
616           plan_result.original_status, plan_result.example_stats, network_stats,
617           run_plan_start_time);
618       break;
619   }
620 }
LogResultUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_result_upload,absl::Time reference_time)621 void LogResultUploadStatus(PhaseLogger& phase_logger, absl::Status result,
622                            const NetworkStats& network_stats,
623                            absl::Time time_before_result_upload,
624                            absl::Time reference_time) {
625   if (result.ok()) {
626     phase_logger.LogResultUploadCompleted(
627         network_stats, time_before_result_upload, reference_time);
628   } else {
629     auto message =
630         absl::StrCat("Error reporting results: code: ", result.code(),
631                      ", message: ", result.message());
632     FCP_LOG(INFO) << message;
633     if (result.code() == absl::StatusCode::kAborted) {
634       phase_logger.LogResultUploadServerAborted(
635           result, network_stats, time_before_result_upload, reference_time);
636     } else if (result.code() == absl::StatusCode::kCancelled) {
637       phase_logger.LogResultUploadClientInterrupted(
638           result, network_stats, time_before_result_upload, reference_time);
639     } else {
640       phase_logger.LogResultUploadIOError(
641           result, network_stats, time_before_result_upload, reference_time);
642     }
643   }
644 }
LogFailureUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_failure_upload,absl::Time reference_time)645 void LogFailureUploadStatus(PhaseLogger& phase_logger, absl::Status result,
646                             const NetworkStats& network_stats,
647                             absl::Time time_before_failure_upload,
648                             absl::Time reference_time) {
649   if (result.ok()) {
650     phase_logger.LogFailureUploadCompleted(
651         network_stats, time_before_failure_upload, reference_time);
652   } else {
653     auto message = absl::StrCat("Error reporting computation failure: code: ",
654                                 result.code(), ", message: ", result.message());
655     FCP_LOG(INFO) << message;
656     if (result.code() == absl::StatusCode::kAborted) {
657       phase_logger.LogFailureUploadServerAborted(
658           result, network_stats, time_before_failure_upload, reference_time);
659     } else if (result.code() == absl::StatusCode::kCancelled) {
660       phase_logger.LogFailureUploadClientInterrupted(
661           result, network_stats, time_before_failure_upload, reference_time);
662     } else {
663       phase_logger.LogFailureUploadIOError(
664           result, network_stats, time_before_failure_upload, reference_time);
665     }
666   }
667 }
ReportPlanResult(FederatedProtocol * federated_protocol,PhaseLogger & phase_logger,absl::StatusOr<ComputationResults> computation_results,absl::Time run_plan_start_time,absl::Time reference_time)668 absl::Status ReportPlanResult(
669     FederatedProtocol* federated_protocol, PhaseLogger& phase_logger,
670     absl::StatusOr<ComputationResults> computation_results,
671     absl::Time run_plan_start_time, absl::Time reference_time) {
672   const absl::Time before_report_time = absl::Now();
673   // Note that the FederatedSelectManager shouldn't be active anymore during
674   // the reporting of results, so we don't bother passing it to
675   // GetNetworkStatsSince.
676   //
677   // We must return only stats that cover the report phase for the log events
678   // below.
679   const NetworkStats before_report_stats =
680       GetCumulativeNetworkStats(federated_protocol,
681                                 /*fedselect_manager=*/nullptr);
682   absl::Status result = absl::InternalError("");
683   if (computation_results.ok()) {
684     FCP_RETURN_IF_ERROR(phase_logger.LogResultUploadStarted());
685     result = federated_protocol->ReportCompleted(
686         std::move(*computation_results),
687         /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
688     LogResultUploadStatus(phase_logger, result,
689                           GetNetworkStatsSince(federated_protocol,
690                                                /*fedselect_manager=*/nullptr,
691                                                before_report_stats),
692                           before_report_time, reference_time);
693   } else {
694     FCP_RETURN_IF_ERROR(phase_logger.LogFailureUploadStarted());
695     result = federated_protocol->ReportNotCompleted(
696         engine::PhaseOutcome::ERROR,
697         /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
698     LogFailureUploadStatus(phase_logger, result,
699                            GetNetworkStatsSince(federated_protocol,
700                                                 /*fedselect_manager=*/nullptr,
701                                                 before_report_stats),
702                            before_report_time, reference_time);
703   }
704   return result;
705 }
706 // Writes the given data to the stream, and returns true if successful and
707 // false if not.
WriteStringOrCordToFstream(std::fstream & stream,const std::variant<std::string,absl::Cord> & data)708 bool WriteStringOrCordToFstream(
709     std::fstream& stream, const std::variant<std::string, absl::Cord>& data) {
710   if (stream.fail()) {
711     return false;
712   }
713   if (std::holds_alternative<std::string>(data)) {
714     return (stream << std::get<std::string>(data)).good();
715   }
716   for (absl::string_view chunk : std::get<absl::Cord>(data).Chunks()) {
717     if (!(stream << chunk).good()) {
718       return false;
719     }
720   }
721   return true;
722 }
723 // Writes the given checkpoint data to a newly created temporary file.
724 // Returns the filename if successful, or an error if the file could not be
725 // created, or if writing to the file failed.
CreateInputCheckpointFile(Files * files,const std::variant<std::string,absl::Cord> & checkpoint)726 absl::StatusOr<std::string> CreateInputCheckpointFile(
727     Files* files, const std::variant<std::string, absl::Cord>& checkpoint) {
728   // Create the temporary checkpoint file.
729   // Deletion of the file is left to the caller / the Files implementation.
730   FCP_ASSIGN_OR_RETURN(absl::StatusOr<std::string> filename,
731                        files->CreateTempFile("init", ".ckp"));
732   // Write the checkpoint data to the file.
733   std::fstream checkpoint_stream(*filename, std::ios_base::out);
734   if (!WriteStringOrCordToFstream(checkpoint_stream, checkpoint)) {
735     return absl::InvalidArgumentError("Failed to write to file");
736   }
737   checkpoint_stream.close();
738   return filename;
739 }
RunEligibilityEvalPlan(const FederatedProtocol::EligibilityEvalTask & eligibility_eval_task,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const absl::Time time_before_checkin,const absl::Time time_before_plan_download,const NetworkStats & network_stats)740 absl::StatusOr<std::optional<TaskEligibilityInfo>> RunEligibilityEvalPlan(
741     const FederatedProtocol::EligibilityEvalTask& eligibility_eval_task,
742     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
743     std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
744     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
745     FederatedProtocol* federated_protocol,
746     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
747     const absl::Time reference_time, const absl::Time time_before_checkin,
748     const absl::Time time_before_plan_download,
749     const NetworkStats& network_stats) {
750   ClientOnlyPlan plan;
751   if (!ParseFromStringOrCord(plan, eligibility_eval_task.payloads.plan)) {
752     auto message = "Failed to parse received eligibility eval plan";
753     phase_logger.LogEligibilityEvalCheckinInvalidPayloadError(
754         message, network_stats, time_before_plan_download);
755     FCP_LOG(ERROR) << message;
756     return absl::InternalError(message);
757   }
758   absl::StatusOr<std::string> checkpoint_input_filename =
759       CreateInputCheckpointFile(files,
760                                 eligibility_eval_task.payloads.checkpoint);
761   if (!checkpoint_input_filename.ok()) {
762     auto status = checkpoint_input_filename.status();
763     auto message = absl::StrCat(
764         "Failed to create eligibility eval checkpoint input file: code: ",
765         status.code(), ", message: ", status.message());
766     phase_logger.LogEligibilityEvalCheckinIOError(status, network_stats,
767                                                   time_before_plan_download);
768     FCP_LOG(ERROR) << message;
769     return absl::InternalError("");
770   }
771   phase_logger.LogEligibilityEvalCheckinCompleted(network_stats,
772                                                   /*time_before_checkin=*/
773                                                   time_before_checkin,
774                                                   /*time_before_plan_download=*/
775                                                   time_before_plan_download);
776   absl::Time run_plan_start_time = absl::Now();
777   phase_logger.LogEligibilityEvalComputationStarted();
778   engine::PlanResult plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
779       example_iterator_factories, should_abort, log_manager, opstats_logger,
780       flags, plan, *checkpoint_input_filename, timing_config,
781       run_plan_start_time, reference_time);
782   absl::StatusOr<TaskEligibilityInfo> task_eligibility_info;
783   if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
784     task_eligibility_info =
785         ParseEligibilityEvalPlanOutput(plan_result.output_tensors);
786   }
787   LogEligibilityEvalComputationOutcome(phase_logger, std::move(plan_result),
788                                        task_eligibility_info.status(),
789                                        run_plan_start_time, reference_time);
790   return task_eligibility_info;
791 }
// Outcome of the eligibility eval phase, handed on to the regular checkin.
struct EligibilityEvalResult {
  // Eligibility info to provide to the regular checkin. std::nullopt when the
  // server indicated that no eligibility eval task is configured. When a
  // PopulationEligibilitySpec was available, it contains only the task weights
  // of single-task-assignment tasks (see CreateEligibilityEvalResult).
  std::optional<TaskEligibilityInfo> task_eligibility_info;
  // Names of tasks from the eligibility eval output that the population spec
  // marks as TASK_ASSIGNMENT_MODE_MULTIPLE. These are kept out of
  // `task_eligibility_info`.
  std::vector<std::string> task_names_for_multiple_task_assignments;
};
796 // Create an EligibilityEvalResult from a TaskEligibilityInfo and
797 // PopulationEligibilitySpec.  If both population_spec and
798 // task_eligibility_info are present, the returned EligibilityEvalResult will
799 // contain a TaskEligibilityInfo which only contains the tasks for single task
800 // assignment, and a vector of task names for multiple task assignment.
CreateEligibilityEvalResult(const std::optional<TaskEligibilityInfo> & task_eligibility_info,const std::optional<PopulationEligibilitySpec> & population_spec)801 EligibilityEvalResult CreateEligibilityEvalResult(
802     const std::optional<TaskEligibilityInfo>& task_eligibility_info,
803     const std::optional<PopulationEligibilitySpec>& population_spec) {
804   EligibilityEvalResult result;
805   if (population_spec.has_value() && task_eligibility_info.has_value()) {
806     absl::flat_hash_set<std::string> task_names_for_multiple_task_assignments;
807     for (const auto& task_info : population_spec.value().task_info()) {
808       if (task_info.task_assignment_mode() ==
809           PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE) {
810         task_names_for_multiple_task_assignments.insert(task_info.task_name());
811       }
812     }
813     TaskEligibilityInfo single_task_assignment_eligibility_info;
814     single_task_assignment_eligibility_info.set_version(
815         task_eligibility_info.value().version());
816     for (const auto& task_weight :
817          task_eligibility_info.value().task_weights()) {
818       if (task_names_for_multiple_task_assignments.contains(
819               task_weight.task_name())) {
820         result.task_names_for_multiple_task_assignments.push_back(
821             task_weight.task_name());
822       } else {
823         *single_task_assignment_eligibility_info.mutable_task_weights()->Add() =
824             task_weight;
825       }
826     }
827     result.task_eligibility_info = single_task_assignment_eligibility_info;
828   } else {
829     result.task_eligibility_info = task_eligibility_info;
830   }
831   return result;
832 }
833 // Issues an eligibility eval checkin request and executes the eligibility
834 // eval task if the server returns one.
835 //
836 // This function modifies the FLRunnerResult with values received over the
837 // course of the eligibility eval protocol interaction.
838 //
839 // Returns:
840 // - the TaskEligibilityInfo produced by the eligibility eval task, if the
841 //   server provided an eligibility eval task to run.
842 // - an std::nullopt if the server indicated that there is no eligibility eval
843 //   task configured for the population.
844 // - an INTERNAL error if the server rejects the client or another error
845 // occurs
846 //   that should abort the training run. The error will already have been
847 //   logged appropriately.
IssueEligibilityEvalCheckinAndRunPlan(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,FLRunnerResult & fl_runner_result)848 absl::StatusOr<EligibilityEvalResult> IssueEligibilityEvalCheckinAndRunPlan(
849     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
850     std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
851     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
852     FederatedProtocol* federated_protocol,
853     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
854     const absl::Time reference_time, FLRunnerResult& fl_runner_result) {
855   const absl::Time time_before_checkin = absl::Now();
856   const NetworkStats network_stats_before_checkin =
857       GetCumulativeNetworkStats(federated_protocol,
858                                 /*fedselect_manager=*/nullptr);
859   // These fields will, after a successful checkin that resulted in an EET
860   // being received, contain the time at which the EET plan/checkpoint URIs
861   // were received (but not yet downloaded), as well as the cumulative network
862   // stats at that point, allowing us to separately calculate how long it took
863   // to then download the actual payloads.
864   absl::Time time_before_plan_download = time_before_checkin;
865   NetworkStats network_stats_before_plan_download =
866       network_stats_before_checkin;
867   // Log that we are about to check in with the server.
868   phase_logger.LogEligibilityEvalCheckinStarted();
869   // Issue the eligibility eval checkin request (providing a callback that
870   // will be called when an EET is assigned to the task but before its
871   // plan/checkpoint URIs have actually been downloaded).
872   bool plan_uris_received_callback_called = false;
873   std::function<void(const FederatedProtocol::EligibilityEvalTask&)>
874       plan_uris_received_callback =
875           [&time_before_plan_download, &network_stats_before_plan_download,
876            &time_before_checkin, &network_stats_before_checkin,
877            &federated_protocol, &phase_logger,
878            &plan_uris_received_callback_called](
879               const FederatedProtocol::EligibilityEvalTask& task) {
880             // When the plan URIs have been received, we already know the name
881             // of the task we have been assigned, so let's tell the
882             // PhaseLogger.
883             phase_logger.SetModelIdentifier(task.execution_id);
884             // We also should log a corresponding log event.
885             phase_logger.LogEligibilityEvalCheckinPlanUriReceived(
886                 GetNetworkStatsSince(federated_protocol,
887                                      /*fedselect_manager=*/nullptr,
888                                      network_stats_before_checkin),
889                 time_before_checkin);
890             // And we must take a snapshot of the current time & network
891             // stats, so we can distinguish between the duration/network stats
892             // incurred for the checkin request vs. the actual downloading of
893             // the plan/checkpoint resources.
894             time_before_plan_download = absl::Now();
895             network_stats_before_plan_download =
896                 GetCumulativeNetworkStats(federated_protocol,
897                                           /*fedselect_manager=*/nullptr);
898             plan_uris_received_callback_called = true;
899           };
900   absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
901       eligibility_checkin_result = federated_protocol->EligibilityEvalCheckin(
902           plan_uris_received_callback);
903   UpdateRetryWindowAndNetworkStats(*federated_protocol,
904                                    /*fedselect_manager=*/nullptr, phase_logger,
905                                    fl_runner_result);
906   // It's a bit unfortunate that we have to inspect the checkin_result and
907   // extract the model identifier here rather than further down the function,
908   // but this ensures that the histograms below will have the right model
909   // identifier attached (and we want to also emit the histograms even if we
910   // have failed/rejected checkin outcomes).
911   if (eligibility_checkin_result.ok() &&
912       std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
913           *eligibility_checkin_result)) {
914     // Make sure that if we received an EligibilityEvalTask, then the callback
915     // should have already been called by this point by the protocol (ensuring
916     // that SetModelIdentifier has been called etc.).
917     FCP_CHECK(plan_uris_received_callback_called);
918   }
919   if (!eligibility_checkin_result.ok()) {
920     auto status = eligibility_checkin_result.status();
921     auto message = absl::StrCat("Error during eligibility eval checkin: code: ",
922                                 status.code(), ", message: ", status.message());
923     if (status.code() == absl::StatusCode::kAborted) {
924       phase_logger.LogEligibilityEvalCheckinServerAborted(
925           status,
926           GetNetworkStatsSince(federated_protocol,
927                                /*fedselect_manager=*/nullptr,
928                                network_stats_before_plan_download),
929           time_before_plan_download);
930     } else if (status.code() == absl::StatusCode::kCancelled) {
931       phase_logger.LogEligibilityEvalCheckinClientInterrupted(
932           status,
933           GetNetworkStatsSince(federated_protocol,
934                                /*fedselect_manager=*/nullptr,
935                                network_stats_before_plan_download),
936           time_before_plan_download);
937     } else if (!status.ok()) {
938       phase_logger.LogEligibilityEvalCheckinIOError(
939           status,
940           GetNetworkStatsSince(federated_protocol,
941                                /*fedselect_manager=*/nullptr,
942                                network_stats_before_plan_download),
943           time_before_plan_download);
944     }
945     FCP_LOG(INFO) << message;
946     return absl::InternalError("");
947   }
948   EligibilityEvalResult result;
949   if (std::holds_alternative<FederatedProtocol::Rejection>(
950           *eligibility_checkin_result)) {
951     phase_logger.LogEligibilityEvalCheckinTurnedAway(
952         GetNetworkStatsSince(federated_protocol,
953                              /*fedselect_manager=*/nullptr,
954                              network_stats_before_checkin),
955         time_before_checkin);
956     // If the server explicitly rejected our request, then we must abort and
957     // we must not proceed to the "checkin" phase below.
958     FCP_LOG(INFO) << "Device rejected by server during eligibility eval "
959                      "checkin; aborting";
960     return absl::InternalError("");
961   } else if (std::holds_alternative<FederatedProtocol::EligibilityEvalDisabled>(
962                  *eligibility_checkin_result)) {
963     phase_logger.LogEligibilityEvalNotConfigured(
964         GetNetworkStatsSince(federated_protocol,
965                              /*fedselect_manager=*/nullptr,
966                              network_stats_before_checkin),
967         time_before_checkin);
968     // If the server indicates that no eligibility eval task is configured for
969     // the population then there is nothing more to do. We simply proceed to
970     // the "checkin" phase below without providing it a TaskEligibilityInfo
971     // proto.
972     result.task_eligibility_info = std::nullopt;
973     return result;
974   }
975   auto eligibility_eval_task =
976       absl::get<FederatedProtocol::EligibilityEvalTask>(
977           *eligibility_checkin_result);
978   // Parse and run the eligibility eval task if the server returned one.
979   // Now we have a EligibilityEvalTask, if an error happens, we will report to
980   // the server via the ReportEligibilityEvalError.
981   absl::StatusOr<std::optional<TaskEligibilityInfo>> task_eligibility_info =
982       RunEligibilityEvalPlan(
983           eligibility_eval_task, example_iterator_factories, should_abort,
984           phase_logger, files, log_manager, opstats_logger, flags,
985           federated_protocol, timing_config, reference_time,
986           /*time_before_checkin=*/time_before_checkin,
987           /*time_before_plan_download=*/time_before_plan_download,
988           GetNetworkStatsSince(federated_protocol,
989                                /*fedselect_manager=*/nullptr,
990                                network_stats_before_plan_download));
991   if (!task_eligibility_info.ok()) {
992     // Note that none of the PhaseLogger methods will reflect the very little
993     // amount of network usage the will be incurred by this protocol request.
994     // We consider this to be OK to keep things simple, and because this
995     // should use such a limited amount of network bandwidth. Do note that the
996     // network usage *will* be correctly reported in the OpStats database.
997     federated_protocol->ReportEligibilityEvalError(
998         absl::Status(task_eligibility_info.status().code(),
999                      "Failed to compute eligibility info"));
1000     UpdateRetryWindowAndNetworkStats(*federated_protocol,
1001                                      /*fedselect_manager=*/nullptr,
1002                                      phase_logger, fl_runner_result);
1003     return task_eligibility_info.status();
1004   }
1005   return CreateEligibilityEvalResult(
1006       *task_eligibility_info,
1007       eligibility_eval_task.population_eligibility_spec);
1008 }
1009 struct CheckinResult {
1010   std::string task_name;
1011   ClientOnlyPlan plan;
1012   int32_t minimum_clients_in_server_visible_aggregate;
1013   std::string checkpoint_input_filename;
1014   std::string computation_id;
1015   std::string federated_select_uri_template;
1016 };
IssueCheckin(PhaseLogger & phase_logger,LogManager * log_manager,Files * files,FederatedProtocol * federated_protocol,std::optional<TaskEligibilityInfo> task_eligibility_info,absl::Time reference_time,const std::string & population_name,FLRunnerResult & fl_runner_result,const Flags * flags)1017 absl::StatusOr<CheckinResult> IssueCheckin(
1018     PhaseLogger& phase_logger, LogManager* log_manager, Files* files,
1019     FederatedProtocol* federated_protocol,
1020     std::optional<TaskEligibilityInfo> task_eligibility_info,
1021     absl::Time reference_time, const std::string& population_name,
1022     FLRunnerResult& fl_runner_result, const Flags* flags) {
1023   absl::Time time_before_checkin = absl::Now();
1024   // We must return only stats that cover the check in phase for the log
1025   // events below.
1026   const NetworkStats network_stats_before_checkin =
1027       GetCumulativeNetworkStats(federated_protocol,
1028                                 /*fedselect_manager=*/nullptr);
1029   // These fields will, after a successful checkin that resulted in a task
1030   // being assigned, contain the time at which the task plan/checkpoint URIs
1031   // were received (but not yet downloaded), as well as the cumulative network
1032   // stats at that point, allowing us to separately calculate how long it took
1033   // to then download the actual payloads.
1034   absl::Time time_before_plan_download = time_before_checkin;
1035   NetworkStats network_stats_before_plan_download =
1036       network_stats_before_checkin;
1037   // Clear the model identifier before check-in, to ensure that the any prior
1038   // eligibility eval task name isn't used any longer.
1039   phase_logger.SetModelIdentifier("");
1040   phase_logger.LogCheckinStarted();
1041   std::string task_name;
1042   // Issue the checkin request (providing a callback that will be called when
1043   // an EET is assigned to the task but before its plan/checkpoint URIs have
1044   // actually been downloaded).
1045   bool plan_uris_received_callback_called = false;
1046   std::function<void(const FederatedProtocol::TaskAssignment&)>
1047       plan_uris_received_callback =
1048           [&time_before_plan_download, &network_stats_before_plan_download,
1049            &time_before_checkin, &network_stats_before_checkin, &task_name,
1050            &federated_protocol, &population_name, &log_manager, &phase_logger,
1051            &plan_uris_received_callback_called](
1052               const FederatedProtocol::TaskAssignment& task_assignment) {
1053             // When the plan URIs have been received, we already know the name
1054             // of the task we have been assigned, so let's tell the
1055             // PhaseLogger.
1056             auto model_identifier = task_assignment.aggregation_session_id;
1057             phase_logger.SetModelIdentifier(model_identifier);
1058             // We also should log a corresponding log event.
1059             task_name = ExtractTaskNameFromAggregationSessionId(
1060                 model_identifier, population_name, *log_manager);
1061             phase_logger.LogCheckinPlanUriReceived(
1062                 task_name,
1063                 GetNetworkStatsSince(federated_protocol,
1064                                      /*fedselect_manager=*/nullptr,
1065                                      network_stats_before_checkin),
1066                 time_before_checkin);
1067             // And we must take a snapshot of the current time & network
1068             // stats, so we can distinguish between the duration/network stats
1069             // incurred for the checkin request vs. the actual downloading of
1070             // the plan/checkpoint resources.
1071             time_before_plan_download = absl::Now();
1072             network_stats_before_plan_download = GetCumulativeNetworkStats(
1073                 federated_protocol, /*fedselect_manager=*/nullptr);
1074             plan_uris_received_callback_called = true;
1075           };
1076   absl::StatusOr<FederatedProtocol::CheckinResult> checkin_result =
1077       federated_protocol->Checkin(task_eligibility_info,
1078                                   plan_uris_received_callback);
1079   UpdateRetryWindowAndNetworkStats(*federated_protocol,
1080                                    /*fedselect_manager=*/nullptr, phase_logger,
1081                                    fl_runner_result);
1082   // It's a bit unfortunate that we have to inspect the checkin_result and
1083   // extract the model identifier here rather than further down the function,
1084   // but this ensures that the histograms below will have the right model
1085   // identifier attached (and we want to also emit the histograms even if we
1086   // have failed/rejected checkin outcomes).
1087   if (checkin_result.ok() &&
1088       std::holds_alternative<FederatedProtocol::TaskAssignment>(
1089           *checkin_result)) {
1090     // Make sure that if we received a TaskAssignment, then the callback
1091     // should have already been called by this point by the protocol (ensuring
1092     // that SetModelIdentifier has been called etc.).
1093     FCP_CHECK(plan_uris_received_callback_called);
1094   }
1095   if (!checkin_result.ok()) {
1096     auto status = checkin_result.status();
1097     auto message = absl::StrCat("Error during checkin: code: ", status.code(),
1098                                 ", message: ", status.message());
1099     if (status.code() == absl::StatusCode::kAborted) {
1100       phase_logger.LogCheckinServerAborted(
1101           status,
1102           GetNetworkStatsSince(federated_protocol,
1103                                /*fedselect_manager=*/nullptr,
1104                                network_stats_before_plan_download),
1105           time_before_plan_download, reference_time);
1106     } else if (status.code() == absl::StatusCode::kCancelled) {
1107       phase_logger.LogCheckinClientInterrupted(
1108           status,
1109           GetNetworkStatsSince(federated_protocol,
1110                                /*fedselect_manager=*/nullptr,
1111                                network_stats_before_plan_download),
1112           time_before_plan_download, reference_time);
1113     } else if (!status.ok()) {
1114       phase_logger.LogCheckinIOError(
1115           status,
1116           GetNetworkStatsSince(federated_protocol,
1117                                /*fedselect_manager=*/nullptr,
1118                                network_stats_before_plan_download),
1119           time_before_plan_download, reference_time);
1120     }
1121     FCP_LOG(INFO) << message;
1122     return status;
1123   }
1124   // Server rejected us? Return the fl_runner_results as-is.
1125   if (std::holds_alternative<FederatedProtocol::Rejection>(*checkin_result)) {
1126     phase_logger.LogCheckinTurnedAway(
1127         GetNetworkStatsSince(federated_protocol,
1128                              /*fedselect_manager=*/nullptr,
1129                              network_stats_before_checkin),
1130         time_before_checkin, reference_time);
1131     FCP_LOG(INFO) << "Device rejected by server during checkin; aborting";
1132     return absl::InternalError("Device rejected by server.");
1133   }
1134   auto task_assignment =
1135       absl::get<FederatedProtocol::TaskAssignment>(*checkin_result);
1136   ClientOnlyPlan plan;
1137   auto plan_bytes = task_assignment.payloads.plan;
1138   if (!ParseFromStringOrCord(plan, plan_bytes)) {
1139     auto message = "Failed to parse received plan";
1140     phase_logger.LogCheckinInvalidPayload(
1141         message,
1142         GetNetworkStatsSince(federated_protocol,
1143                              /*fedselect_manager=*/nullptr,
1144                              network_stats_before_plan_download),
1145         time_before_plan_download, reference_time);
1146     FCP_LOG(ERROR) << message;
1147     return absl::InternalError("");
1148   }
1149   std::string computation_id;
1150   if (flags->enable_computation_id()) {
1151     computation_id = ComputeSHA256FromStringOrCord(plan_bytes);
1152   }
1153   int32_t minimum_clients_in_server_visible_aggregate = 0;
1154   if (task_assignment.sec_agg_info.has_value()) {
1155     auto minimum_number_of_participants =
1156         plan.phase().minimum_number_of_participants();
1157     if (task_assignment.sec_agg_info->expected_number_of_clients <
1158         minimum_number_of_participants) {
1159       return absl::InternalError(
1160           "expectedNumberOfClients was less than Plan's "
1161           "minimumNumberOfParticipants.");
1162     }
1163     minimum_clients_in_server_visible_aggregate =
1164         task_assignment.sec_agg_info
1165             ->minimum_clients_in_server_visible_aggregate;
1166   }
1167   absl::StatusOr<std::string> checkpoint_input_filename = "";
1168   // Example query plan does not have an input checkpoint.
1169   if (!plan.phase().has_example_query_spec()) {
1170     checkpoint_input_filename =
1171         CreateInputCheckpointFile(files, task_assignment.payloads.checkpoint);
1172     if (!checkpoint_input_filename.ok()) {
1173       auto status = checkpoint_input_filename.status();
1174       auto message = absl::StrCat(
1175           "Failed to create checkpoint input file: code: ", status.code(),
1176           ", message: ", status.message());
1177       phase_logger.LogCheckinIOError(
1178           status,
1179           GetNetworkStatsSince(federated_protocol,
1180                                /*fedselect_manager=*/nullptr,
1181                                network_stats_before_plan_download),
1182           time_before_plan_download, reference_time);
1183       FCP_LOG(ERROR) << message;
1184       return status;
1185     }
1186   }
1187   phase_logger.LogCheckinCompleted(
1188       task_name,
1189       GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr,
1190                            network_stats_before_plan_download),
1191       /*time_before_checkin=*/time_before_checkin,
1192       /*time_before_plan_download=*/time_before_plan_download, reference_time);
1193   return CheckinResult{
1194       .task_name = std::move(task_name),
1195       .plan = std::move(plan),
1196       .minimum_clients_in_server_visible_aggregate =
1197           minimum_clients_in_server_visible_aggregate,
1198       .checkpoint_input_filename = std::move(*checkpoint_input_filename),
1199       .computation_id = std::move(computation_id),
1200       .federated_select_uri_template =
1201           task_assignment.federated_select_uri_template};
1202 }
1203 }  // namespace
RunFederatedComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,const std::string & session_name,const std::string & population_name,const std::string & retry_token,const std::string & client_version,const std::string & attestation_measurement)1204 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1205     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1206     Files* files, LogManager* log_manager, const Flags* flags,
1207     const std::string& federated_service_uri, const std::string& api_key,
1208     const std::string& test_cert_path, const std::string& session_name,
1209     const std::string& population_name, const std::string& retry_token,
1210     const std::string& client_version,
1211     const std::string& attestation_measurement) {
1212   auto opstats_logger =
1213       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1214                                   session_name, population_name);
1215   absl::Time reference_time = absl::Now();
1216   FLRunnerResult fl_runner_result;
1217   fcp::client::InterruptibleRunner::TimingConfig timing_config = {
1218       .polling_period =
1219           absl::Milliseconds(flags->condition_polling_period_millis()),
1220       .graceful_shutdown_period = absl::Milliseconds(
1221           flags->tf_execution_teardown_grace_period_millis()),
1222       .extended_shutdown_period = absl::Milliseconds(
1223           flags->tf_execution_teardown_extended_period_millis()),
1224   };
1225   auto should_abort_protocol_callback = [&env_deps, &timing_config]() -> bool {
1226     // Return the Status if failed, or the negated value if successful.
1227     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1228   };
1229   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1230                                log_manager, flags);
1231   // If there was an error initializing OpStats, opstats_logger will be a
1232   // no-op implementation and execution will be allowed to continue.
1233   if (!opstats_logger->GetInitStatus().ok()) {
1234     // This will only happen if OpStats is enabled and there was an error in
1235     // initialization.
1236     phase_logger.LogNonfatalInitializationError(
1237         opstats_logger->GetInitStatus());
1238   }
1239   Clock* clock = Clock::RealClock();
1240   std::unique_ptr<cache::ResourceCache> resource_cache;
1241   // if (flags->max_resource_cache_size_bytes() > 0) {
1242   //   // Anything that goes wrong in FileBackedResourceCache::Create is a
1243   //   // programmer error.
1244   //   absl::StatusOr<std::unique_ptr<cache::ResourceCache>>
1245   //       resource_cache_internal = cache::FileBackedResourceCache::Create(
1246   //           env_deps->GetBaseDir(), env_deps->GetCacheDir(), log_manager,
1247   //           clock, flags->max_resource_cache_size_bytes());
1248   //   if (!resource_cache_internal.ok()) {
1249   //     auto resource_init_failed_status = absl::Status(
1250   //         resource_cache_internal.status().code(),
1251   //         absl::StrCat("Failed to initialize FileBackedResourceCache: ",
1252   //                      resource_cache_internal.status().ToString()));
1253   //     if (flags->resource_cache_initialization_error_is_fatal()) {
1254   //       phase_logger.LogFatalInitializationError(resource_init_failed_status);
1255   //       return resource_init_failed_status;
1256   //     }
1257   //     // We log an error but otherwise proceed as if the cache was disabled.
1258   //     phase_logger.LogNonfatalInitializationError(resource_init_failed_status);
1259   //   } else {
1260   //     resource_cache = std::move(*resource_cache_internal);
1261   //   }
1262   // }
1263   std::unique_ptr<::fcp::client::http::HttpClient> http_client =
1264       flags->enable_grpc_with_http_resource_support() ||
1265               flags->use_http_federated_compute_protocol()
1266           ? env_deps->CreateHttpClient()
1267           : nullptr;
1268   std::unique_ptr<FederatedProtocol> federated_protocol;
1269   if (flags->use_http_federated_compute_protocol()) {
1270     log_manager->LogDiag(ProdDiagCode::HTTP_FEDERATED_PROTOCOL_USED);
1271     // Verify the entry point uri starts with "https://" or
1272     // "http://localhost". Note "http://localhost" is allowed for testing
1273     // purpose.
1274     if (!(absl::StartsWith(federated_service_uri, "https://") ||
1275           absl::StartsWith(federated_service_uri, "http://localhost"))) {
1276       return absl::InvalidArgumentError("The entry point uri is invalid.");
1277     }
1278     federated_protocol = std::make_unique<http::HttpFederatedProtocol>(
1279         clock, log_manager, flags, http_client.get(),
1280         std::make_unique<SecAggRunnerFactoryImpl>(),
1281         event_publisher->secagg_event_publisher(), federated_service_uri,
1282         api_key, population_name, retry_token, client_version,
1283         attestation_measurement, should_abort_protocol_callback, absl::BitGen(),
1284         timing_config, resource_cache.get());
1285   } else {
1286 #ifdef FCP_CLIENT_SUPPORT_GRPC
1287     // Check in with the server to either retrieve a plan + initial
1288     // checkpoint, or get rejected with a RetryWindow.
1289     auto grpc_channel_deadline = flags->grpc_channel_deadline_seconds();
1290     if (grpc_channel_deadline <= 0) {
1291       grpc_channel_deadline = 600;
1292       FCP_LOG(INFO) << "Using default channel deadline of "
1293                     << grpc_channel_deadline << " seconds.";
1294     }
1295     federated_protocol = std::make_unique<GrpcFederatedProtocol>(
1296         event_publisher, log_manager,
1297         std::make_unique<SecAggRunnerFactoryImpl>(), flags, http_client.get(),
1298         federated_service_uri, api_key, test_cert_path, population_name,
1299         retry_token, client_version, attestation_measurement,
1300         should_abort_protocol_callback, timing_config, grpc_channel_deadline,
1301         resource_cache.get());
1302 #else
1303     return absl::InternalError("No FederatedProtocol enabled.");
1304 #endif
1305   }
1306   std::unique_ptr<FederatedSelectManager> federated_select_manager;
1307   if (http_client != nullptr && flags->enable_federated_select()) {
1308     federated_select_manager = std::make_unique<HttpFederatedSelectManager>(
1309         log_manager, files, http_client.get(), should_abort_protocol_callback,
1310         timing_config);
1311   } else {
1312     federated_select_manager =
1313         std::make_unique<DisabledFederatedSelectManager>(log_manager);
1314   }
1315   return RunFederatedComputation(env_deps, phase_logger, event_publisher, files,
1316                                  log_manager, opstats_logger.get(), flags,
1317                                  federated_protocol.get(),
1318                                  federated_select_manager.get(), timing_config,
1319                                  reference_time, session_name, population_name);
1320 }
RunFederatedComputation(SimpleTaskEnvironment * env_deps,PhaseLogger & phase_logger,EventPublisher * event_publisher,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const std::string & session_name,const std::string & population_name)1321 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1322     SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger,
1323     EventPublisher* event_publisher, Files* files, LogManager* log_manager,
1324     OpStatsLogger* opstats_logger, const Flags* flags,
1325     FederatedProtocol* federated_protocol,
1326     FederatedSelectManager* fedselect_manager,
1327     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1328     const absl::Time reference_time, const std::string& session_name,
1329     const std::string& population_name) {
1330   SelectorContext federated_selector_context;
1331   federated_selector_context.mutable_computation_properties()->set_session_name(
1332       session_name);
1333   FederatedComputation federated_computation;
1334   federated_computation.set_population_name(population_name);
1335   *federated_selector_context.mutable_computation_properties()
1336        ->mutable_federated() = federated_computation;
1337   SelectorContext eligibility_selector_context;
1338   eligibility_selector_context.mutable_computation_properties()
1339       ->set_session_name(session_name);
1340   EligibilityEvalComputation eligibility_eval_computation;
1341   eligibility_eval_computation.set_population_name(population_name);
1342   *eligibility_selector_context.mutable_computation_properties()
1343        ->mutable_eligibility_eval() = eligibility_eval_computation;
1344   // Construct a default FLRunnerResult that reflects an unsuccessful training
1345   // attempt and which uses RetryWindow corresponding to transient errors (if
1346   // the flag is on).
1347   // This is what will be returned if we have to bail early, before we've
1348   // received a RetryWindow from the server.
1349   FLRunnerResult fl_runner_result;
1350   fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
1351   // Before we even check whether we should abort right away, update the retry
1352   // window. That way we will use the most appropriate retry window we have
1353   // available (an implementation detail of FederatedProtocol, but generally a
1354   // 'transient error' retry window based on the provided flag values) in case
1355   // we do need to abort.
1356   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1357                                    phase_logger, fl_runner_result);
1358   // Check if the device conditions allow for checking in with the server
1359   // and running a federated computation. If not, bail early with the
1360   // transient error retry window.
1361   std::function<bool()> should_abort = [env_deps, &timing_config]() {
1362     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1363   };
1364   if (should_abort()) {
1365     std::string message =
1366         "Device conditions not satisfied, aborting federated computation";
1367     FCP_LOG(INFO) << message;
1368     phase_logger.LogTaskNotStarted(message);
1369     return fl_runner_result;
1370   }
1371   // Eligibility eval plans can use example iterators from the
1372   // SimpleTaskEnvironment and those reading the OpStats DB.
1373   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1374       opstats_logger, log_manager,
1375       flags->opstats_last_successful_contribution_criteria());
1376   std::unique_ptr<engine::ExampleIteratorFactory>
1377       env_eligibility_example_iterator_factory =
1378           CreateSimpleTaskEnvironmentIteratorFactory(
1379               env_deps, eligibility_selector_context);
1380   std::vector<engine::ExampleIteratorFactory*>
1381       eligibility_example_iterator_factories{
1382           &opstats_example_iterator_factory,
1383           env_eligibility_example_iterator_factory.get()};
1384   // Note that this method will update fl_runner_result's fields with values
1385   // received over the course of the eligibility eval protocol interaction.
1386   absl::StatusOr<EligibilityEvalResult> eligibility_eval_result =
1387       IssueEligibilityEvalCheckinAndRunPlan(
1388           eligibility_example_iterator_factories, should_abort, phase_logger,
1389           files, log_manager, opstats_logger, flags, federated_protocol,
1390           timing_config, reference_time, fl_runner_result);
1391   if (!eligibility_eval_result.ok()) {
1392     return fl_runner_result;
1393   }
1394   auto checkin_result =
1395       IssueCheckin(phase_logger, log_manager, files, federated_protocol,
1396                    std::move(eligibility_eval_result->task_eligibility_info),
1397                    reference_time, population_name, fl_runner_result, flags);
1398   if (!checkin_result.ok()) {
1399     return fl_runner_result;
1400   }
1401   SelectorContext federated_selector_context_with_task_name =
1402       federated_selector_context;
1403   federated_selector_context_with_task_name.mutable_computation_properties()
1404       ->mutable_federated()
1405       ->set_task_name(checkin_result->task_name);
1406   if (flags->enable_computation_id()) {
1407     federated_selector_context_with_task_name.mutable_computation_properties()
1408         ->mutable_federated()
1409         ->set_computation_id(checkin_result->computation_id);
1410   }
1411   if (checkin_result->plan.phase().has_example_query_spec()) {
1412     federated_selector_context_with_task_name.mutable_computation_properties()
1413         ->set_example_iterator_output_format(
1414             ::fcp::client::QueryTimeComputationProperties::
1415                 EXAMPLE_QUERY_RESULT);
1416   }
1417   // Include the last successful contribution timestamp in the
1418   // SelectorContext.
1419   const auto& opstats_db = opstats_logger->GetOpStatsDb();
1420   if (opstats_db != nullptr) {
1421     absl::StatusOr<opstats::OpStatsSequence> data = opstats_db->Read();
1422     if (data.ok()) {
1423       std::optional<google::protobuf::Timestamp>
1424           last_successful_contribution_time =
1425               opstats::GetLastSuccessfulContributionTime(
1426                   *data, checkin_result->task_name);
1427       if (last_successful_contribution_time.has_value()) {
1428         *(federated_selector_context_with_task_name
1429               .mutable_computation_properties()
1430               ->mutable_federated()
1431               ->mutable_historical_context()
1432               ->mutable_last_successful_contribution_time()) =
1433             *last_successful_contribution_time;
1434       }
1435     }
1436   }
1437   if (checkin_result->plan.phase().has_example_query_spec()) {
1438     // Example query plan only supports simple agg for now.
1439     *(federated_selector_context_with_task_name
1440           .mutable_computation_properties()
1441           ->mutable_federated()
1442           ->mutable_simple_aggregation()) = SimpleAggregation();
1443   } else {
1444     const auto& federated_compute_io_router =
1445         checkin_result->plan.phase().federated_compute();
1446     const bool has_simpleagg_tensors =
1447         !federated_compute_io_router.output_filepath_tensor_name().empty();
1448     bool all_aggregations_are_secagg = true;
1449     for (const auto& aggregation : federated_compute_io_router.aggregations()) {
1450       all_aggregations_are_secagg &=
1451           aggregation.second.protocol_config_case() ==
1452           AggregationConfig::kSecureAggregation;
1453     }
1454     if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
1455       federated_selector_context_with_task_name
1456           .mutable_computation_properties()
1457           ->mutable_federated()
1458           ->mutable_secure_aggregation()
1459           ->set_minimum_clients_in_server_visible_aggregate(
1460               checkin_result->minimum_clients_in_server_visible_aggregate);
1461     } else {
1462       // Has an output checkpoint, so some tensors must be simply aggregated.
1463       *(federated_selector_context_with_task_name
1464             .mutable_computation_properties()
1465             ->mutable_federated()
1466             ->mutable_simple_aggregation()) = SimpleAggregation();
1467     }
1468   }
1469   RetryWindow report_retry_window;
1470   phase_logger.LogComputationStarted();
1471   absl::Time run_plan_start_time = absl::Now();
1472   NetworkStats run_plan_start_network_stats =
1473       GetCumulativeNetworkStats(federated_protocol, fedselect_manager);
1474   absl::StatusOr<std::string> checkpoint_output_filename =
1475       files->CreateTempFile("output", ".ckp");
1476   if (!checkpoint_output_filename.ok()) {
1477     auto status = checkpoint_output_filename.status();
1478     auto message = absl::StrCat(
1479         "Could not create temporary output checkpoint file: code: ",
1480         status.code(), ", message: ", status.message());
1481     phase_logger.LogComputationIOError(
1482         status, ExampleStats(),
1483         GetNetworkStatsSince(federated_protocol, fedselect_manager,
1484                              run_plan_start_network_stats),
1485         run_plan_start_time);
1486     return fl_runner_result;
1487   }
1488   // Regular plans can use example iterators from the SimpleTaskEnvironment,
1489   // those reading the OpStats DB, or those serving Federated Select slices.
1490   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1491       CreateSimpleTaskEnvironmentIteratorFactory(
1492           env_deps, federated_selector_context_with_task_name);
1493   std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
1494       fedselect_example_iterator_factory =
1495           fedselect_manager->CreateExampleIteratorFactoryForUriTemplate(
1496               checkin_result->federated_select_uri_template);
1497   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1498       fedselect_example_iterator_factory.get(),
1499       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1500   PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1501       checkin_result->plan.phase().has_example_query_spec()
1502           ? RunPlanWithExampleQuerySpec(
1503                 example_iterator_factories, opstats_logger, flags,
1504                 checkin_result->plan, *checkpoint_output_filename)
1505           : RunPlanWithTensorflowSpec(
1506                 example_iterator_factories, should_abort, log_manager,
1507                 opstats_logger, flags, checkin_result->plan,
1508                 checkin_result->checkpoint_input_filename,
1509                 *checkpoint_output_filename, timing_config);
1510   // Update the FLRunnerResult fields to account for any network usage during
1511   // the execution of the plan (e.g. due to Federated Select slices having
1512   // been fetched).
1513   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1514                                    phase_logger, fl_runner_result);
1515   auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
1516   absl::StatusOr<ComputationResults> computation_results;
1517   if (outcome == engine::PlanOutcome::kSuccess) {
1518     computation_results = CreateComputationResults(
1519         checkin_result->plan.phase().has_example_query_spec()
1520             ? nullptr
1521             : &checkin_result->plan.phase().tensorflow_spec(),
1522         plan_result_and_checkpoint_file);
1523   }
1524   LogComputationOutcome(
1525       plan_result_and_checkpoint_file.plan_result, computation_results.status(),
1526       phase_logger,
1527       GetNetworkStatsSince(federated_protocol, fedselect_manager,
1528                            run_plan_start_network_stats),
1529       run_plan_start_time, reference_time);
1530   absl::Status report_result = ReportPlanResult(
1531       federated_protocol, phase_logger, std::move(computation_results),
1532       run_plan_start_time, reference_time);
1533   if (outcome == engine::PlanOutcome::kSuccess && report_result.ok()) {
1534     // Only if training succeeded *and* reporting succeeded do we consider
1535     // the device to have contributed successfully.
1536     fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
1537   }
1538   // Update the FLRunnerResult fields one more time to account for the
1539   // "Report" protocol interaction.
1540   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1541                                    phase_logger, fl_runner_result);
1542   return fl_runner_result;
1543 }
RunPlanWithTensorflowSpecForTesting(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)1544 FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
1545     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1546     Files* files, LogManager* log_manager, const Flags* flags,
1547     const ClientOnlyPlan& client_plan,
1548     const std::string& checkpoint_input_filename,
1549     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1550     const absl::Time run_plan_start_time, const absl::Time reference_time) {
1551   FLRunnerTensorflowSpecResult result;
1552   result.set_outcome(engine::PhaseOutcome::ERROR);
1553   engine::PlanResult plan_result(engine::PlanOutcome::kTensorflowError,
1554                                  absl::UnknownError(""));
1555   std::function<bool()> should_abort = [env_deps, &timing_config]() {
1556     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1557   };
1558   auto opstats_logger =
1559       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1560                                   /*session_name=*/"", /*population_name=*/"");
1561   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1562                                log_manager, flags);
1563   // Regular plans can use example iterators from the SimpleTaskEnvironment,
1564   // those reading the OpStats DB, or those serving Federated Select slices.
1565   // However, we don't provide a Federated Select-specific example iterator
1566   // factory. That way, the Federated Select slice queries will be forwarded
1567   // to SimpleTaskEnvironment, which can handle them by providing
1568   // test-specific slices if they want to.
1569   //
1570   // Eligibility eval plans can only use iterators from the
1571   // SimpleTaskEnvironment and those reading the OpStats DB.
1572   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1573       opstats_logger.get(), log_manager,
1574       flags->opstats_last_successful_contribution_criteria());
1575   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1576       CreateSimpleTaskEnvironmentIteratorFactory(env_deps, SelectorContext());
1577   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1578       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1579   phase_logger.LogComputationStarted();
1580   if (client_plan.phase().has_federated_compute()) {
1581     absl::StatusOr<std::string> checkpoint_output_filename =
1582         files->CreateTempFile("output", ".ckp");
1583     if (!checkpoint_output_filename.ok()) {
1584       phase_logger.LogComputationIOError(
1585           checkpoint_output_filename.status(), ExampleStats(),
1586           // Empty network stats, since no network protocol is actually used
1587           // in this method.
1588           NetworkStats(), run_plan_start_time);
1589       return result;
1590     }
1591     // Regular TensorflowSpec-based plans.
1592     PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1593         RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
1594                                   log_manager, opstats_logger.get(), flags,
1595                                   client_plan, checkpoint_input_filename,
1596                                   *checkpoint_output_filename, timing_config);
1597     result.set_checkpoint_output_filename(
1598         plan_result_and_checkpoint_file.checkpoint_file);
1599     plan_result = std::move(plan_result_and_checkpoint_file.plan_result);
1600   } else if (client_plan.phase().has_federated_compute_eligibility()) {
1601     // Eligibility eval plans.
1602     plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
1603         example_iterator_factories, should_abort, log_manager,
1604         opstats_logger.get(), flags, client_plan, checkpoint_input_filename,
1605         timing_config, run_plan_start_time, reference_time);
1606   } else {
1607     // This branch shouldn't be taken, unless we add an additional type of
1608     // TensorflowSpec-based plan in the future. We return a readable error so
1609     // that when such new plan types *are* added, they result in clear
1610     // compatibility test failures when such plans are erroneously targeted at
1611     // old releases that don't support them yet.
1612     event_publisher->PublishIoError("Unsupported TensorflowSpec-based plan");
1613     return result;
1614   }
1615   // Copy output tensors into the result proto.
1616   result.set_outcome(
1617       engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome));
1618   if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
1619     for (int i = 0; i < plan_result.output_names.size(); i++) {
1620       tensorflow::TensorProto output_tensor_proto;
1621       plan_result.output_tensors[i].AsProtoField(&output_tensor_proto);
1622       (*result.mutable_output_tensors())[plan_result.output_names[i]] =
1623           std::move(output_tensor_proto);
1624     }
1625     phase_logger.LogComputationCompleted(
1626         plan_result.example_stats,
1627         // Empty network stats, since no network protocol is actually used in
1628         // this method.
1629         NetworkStats(), run_plan_start_time, reference_time);
1630   } else {
1631     phase_logger.LogComputationTensorflowError(
1632         plan_result.original_status, plan_result.example_stats, NetworkStats(),
1633         run_plan_start_time, reference_time);
1634   }
1635   return result;
1636 }
1637 }  // namespace client
1638 }  // namespace fcp
1639