xref: /aosp_15_r20/external/federated-compute/fcp/client/fl_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <fcntl.h>
19*14675a02SAndroid Build Coastguard Worker 
20*14675a02SAndroid Build Coastguard Worker #include <fstream>
21*14675a02SAndroid Build Coastguard Worker #include <functional>
22*14675a02SAndroid Build Coastguard Worker #include <map>
23*14675a02SAndroid Build Coastguard Worker #include <memory>
24*14675a02SAndroid Build Coastguard Worker #include <optional>
25*14675a02SAndroid Build Coastguard Worker #include <string>
26*14675a02SAndroid Build Coastguard Worker #include <utility>
27*14675a02SAndroid Build Coastguard Worker #include <variant>
28*14675a02SAndroid Build Coastguard Worker #include <vector>
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
31*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
32*14675a02SAndroid Build Coastguard Worker #include "absl/strings/cord.h"
33*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
34*14675a02SAndroid Build Coastguard Worker #include "fcp/base/clock.h"
35*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
36*14675a02SAndroid Build Coastguard Worker #include "fcp/base/platform.h"
37*14675a02SAndroid Build Coastguard Worker // #include "fcp/client/cache/file_backed_resource_cache.h"
38*14675a02SAndroid Build Coastguard Worker #include "fcp/client/cache/resource_cache.h"
39*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/common.h"
40*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/engine.pb.h"
41*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_iterator_factory.h"
42*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_query_plan_engine.h"
43*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
44*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_utils.h"
45*14675a02SAndroid Build Coastguard Worker #include "fcp/client/parsing_utils.h"
46*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
47*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/simple_plan_engine.h"
48*14675a02SAndroid Build Coastguard Worker #endif
49*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_plan_engine.h"
50*14675a02SAndroid Build Coastguard Worker #include "fcp/client/event_publisher.h"
51*14675a02SAndroid Build Coastguard Worker #include "fcp/client/federated_protocol.h"
52*14675a02SAndroid Build Coastguard Worker #include "fcp/client/federated_protocol_util.h"
53*14675a02SAndroid Build Coastguard Worker #include "fcp/client/files.h"
54*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.pb.h"
55*14675a02SAndroid Build Coastguard Worker #include "fcp/client/flags.h"
56*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/http_federated_protocol.h"
57*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_GRPC
58*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_federated_protocol.h"
59*14675a02SAndroid Build Coastguard Worker #endif
60*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
61*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h"
62*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_example_store.h"
63*14675a02SAndroid Build Coastguard Worker #include "fcp/client/phase_logger_impl.h"
64*14675a02SAndroid Build Coastguard Worker #include "fcp/client/secagg_runner.h"
65*14675a02SAndroid Build Coastguard Worker #include "fcp/client/selector_context.pb.h"
66*14675a02SAndroid Build Coastguard Worker #include "fcp/client/simple_task_environment.h"
67*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federated_api.pb.h"
68*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
69*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/opstats.pb.h"
70*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
71*14675a02SAndroid Build Coastguard Worker #include "openssl/digest.h"
72*14675a02SAndroid Build Coastguard Worker #include "openssl/evp.h"
73*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
74*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.pb.h"
75*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.pb.h"
76*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h"
77*14675a02SAndroid Build Coastguard Worker namespace fcp {
78*14675a02SAndroid Build Coastguard Worker namespace client {
79*14675a02SAndroid Build Coastguard Worker using ::fcp::client::opstats::OpStatsLogger;
80*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::AggregationConfig;
81*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::ClientOnlyPlan;
82*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
83*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::FederatedComputeIORouter;
84*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::TensorflowSpec;
85*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
86*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::RetryWindow;
87*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::TaskEligibilityInfo;
88*14675a02SAndroid Build Coastguard Worker using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
89*14675a02SAndroid Build Coastguard Worker namespace {
90*14675a02SAndroid Build Coastguard Worker template <typename T>
AddValuesToQuantized(QuantizedTensor * quantized,const tensorflow::Tensor & tensor)91*14675a02SAndroid Build Coastguard Worker void AddValuesToQuantized(QuantizedTensor* quantized,
92*14675a02SAndroid Build Coastguard Worker                           const tensorflow::Tensor& tensor) {
93*14675a02SAndroid Build Coastguard Worker   auto flat_tensor = tensor.flat<T>();
94*14675a02SAndroid Build Coastguard Worker   quantized->values.reserve(quantized->values.size() + flat_tensor.size());
95*14675a02SAndroid Build Coastguard Worker   for (int i = 0; i < flat_tensor.size(); i++) {
96*14675a02SAndroid Build Coastguard Worker     quantized->values.push_back(flat_tensor(i));
97*14675a02SAndroid Build Coastguard Worker   }
98*14675a02SAndroid Build Coastguard Worker }
ComputeSHA256FromStringOrCord(std::variant<std::string,absl::Cord> data)99*14675a02SAndroid Build Coastguard Worker std::string ComputeSHA256FromStringOrCord(
100*14675a02SAndroid Build Coastguard Worker     std::variant<std::string, absl::Cord> data) {
101*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
102*14675a02SAndroid Build Coastguard Worker                                                            EVP_MD_CTX_destroy);
103*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr));
104*14675a02SAndroid Build Coastguard Worker   std::string plan_str;
105*14675a02SAndroid Build Coastguard Worker   if (std::holds_alternative<std::string>(data)) {
106*14675a02SAndroid Build Coastguard Worker     plan_str = std::get<std::string>(data);
107*14675a02SAndroid Build Coastguard Worker   } else {
108*14675a02SAndroid Build Coastguard Worker     plan_str = std::string(std::get<absl::Cord>(data));
109*14675a02SAndroid Build Coastguard Worker   }
110*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(EVP_DigestUpdate(mdctx.get(), plan_str.c_str(), sizeof(int)));
111*14675a02SAndroid Build Coastguard Worker   const int hash_len = 32;  // 32 bytes for SHA-256.
112*14675a02SAndroid Build Coastguard Worker   uint8_t computation_id_bytes[hash_len];
113*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(EVP_DigestFinal_ex(mdctx.get(), computation_id_bytes, nullptr));
114*14675a02SAndroid Build Coastguard Worker   return std::string(reinterpret_cast<char const*>(computation_id_bytes),
115*14675a02SAndroid Build Coastguard Worker                      hash_len);
116*14675a02SAndroid Build Coastguard Worker }
117*14675a02SAndroid Build Coastguard Worker struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anon340cd2590111::PlanResultAndCheckpointFile118*14675a02SAndroid Build Coastguard Worker   explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
119*14675a02SAndroid Build Coastguard Worker       : plan_result(std::move(plan_result)) {}
120*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result;
121*14675a02SAndroid Build Coastguard Worker   std::string checkpoint_file;
122*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
123*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
124*14675a02SAndroid Build Coastguard Worker       default;
125*14675a02SAndroid Build Coastguard Worker   // Disallow copy and assign.
126*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
127*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
128*14675a02SAndroid Build Coastguard Worker       delete;
129*14675a02SAndroid Build Coastguard Worker };
130*14675a02SAndroid Build Coastguard Worker // Creates computation results. The method checks for SecAgg tensors only if
131*14675a02SAndroid Build Coastguard Worker // `tensorflow_spec != nullptr`.
CreateComputationResults(const TensorflowSpec * tensorflow_spec,const PlanResultAndCheckpointFile & plan_result_and_checkpoint_file)132*14675a02SAndroid Build Coastguard Worker absl::StatusOr<ComputationResults> CreateComputationResults(
133*14675a02SAndroid Build Coastguard Worker     const TensorflowSpec* tensorflow_spec,
134*14675a02SAndroid Build Coastguard Worker     const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file) {
135*14675a02SAndroid Build Coastguard Worker   const auto& [plan_result, checkpoint_file] = plan_result_and_checkpoint_file;
136*14675a02SAndroid Build Coastguard Worker   if (plan_result.outcome != engine::PlanOutcome::kSuccess) {
137*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Computation failed.");
138*14675a02SAndroid Build Coastguard Worker   }
139*14675a02SAndroid Build Coastguard Worker   ComputationResults computation_results;
140*14675a02SAndroid Build Coastguard Worker   if (tensorflow_spec != nullptr) {
141*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < plan_result.output_names.size(); i++) {
142*14675a02SAndroid Build Coastguard Worker       QuantizedTensor quantized;
143*14675a02SAndroid Build Coastguard Worker       const auto& output_tensor = plan_result.output_tensors[i];
144*14675a02SAndroid Build Coastguard Worker       switch (output_tensor.dtype()) {
145*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_INT8:
146*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<int8_t>(&quantized, output_tensor);
147*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 7;
148*14675a02SAndroid Build Coastguard Worker           break;
149*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_UINT8:
150*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<uint8_t>(&quantized, output_tensor);
151*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 8;
152*14675a02SAndroid Build Coastguard Worker           break;
153*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_INT16:
154*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<int16_t>(&quantized, output_tensor);
155*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 15;
156*14675a02SAndroid Build Coastguard Worker           break;
157*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_UINT16:
158*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<uint16_t>(&quantized, output_tensor);
159*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 16;
160*14675a02SAndroid Build Coastguard Worker           break;
161*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_INT32:
162*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<int32_t>(&quantized, output_tensor);
163*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 31;
164*14675a02SAndroid Build Coastguard Worker           break;
165*14675a02SAndroid Build Coastguard Worker         case tensorflow::DT_INT64:
166*14675a02SAndroid Build Coastguard Worker           AddValuesToQuantized<tensorflow::int64>(&quantized, output_tensor);
167*14675a02SAndroid Build Coastguard Worker           quantized.bitwidth = 62;
168*14675a02SAndroid Build Coastguard Worker           break;
169*14675a02SAndroid Build Coastguard Worker         default:
170*14675a02SAndroid Build Coastguard Worker           return absl::InvalidArgumentError(
171*14675a02SAndroid Build Coastguard Worker               absl::StrCat("Tensor of type",
172*14675a02SAndroid Build Coastguard Worker                            tensorflow::DataType_Name(output_tensor.dtype()),
173*14675a02SAndroid Build Coastguard Worker                            "could not be converted to quantized value"));
174*14675a02SAndroid Build Coastguard Worker       }
175*14675a02SAndroid Build Coastguard Worker       computation_results[plan_result.output_names[i]] = std::move(quantized);
176*14675a02SAndroid Build Coastguard Worker     }
177*14675a02SAndroid Build Coastguard Worker     // Add dimensions to QuantizedTensors.
178*14675a02SAndroid Build Coastguard Worker     for (const tensorflow::TensorSpecProto& tensor_spec :
179*14675a02SAndroid Build Coastguard Worker          tensorflow_spec->output_tensor_specs()) {
180*14675a02SAndroid Build Coastguard Worker       if (computation_results.find(tensor_spec.name()) !=
181*14675a02SAndroid Build Coastguard Worker           computation_results.end()) {
182*14675a02SAndroid Build Coastguard Worker         for (const tensorflow::TensorShapeProto_Dim& dim :
183*14675a02SAndroid Build Coastguard Worker              tensor_spec.shape().dim()) {
184*14675a02SAndroid Build Coastguard Worker           std::get<QuantizedTensor>(computation_results[tensor_spec.name()])
185*14675a02SAndroid Build Coastguard Worker               .dimensions.push_back(dim.size());
186*14675a02SAndroid Build Coastguard Worker         }
187*14675a02SAndroid Build Coastguard Worker       }
188*14675a02SAndroid Build Coastguard Worker     }
189*14675a02SAndroid Build Coastguard Worker   }
190*14675a02SAndroid Build Coastguard Worker   // Name of the TF checkpoint inside the aggregand map in the Checkpoint
191*14675a02SAndroid Build Coastguard Worker   // protobuf. This field name is ignored by the server.
192*14675a02SAndroid Build Coastguard Worker   if (!checkpoint_file.empty()) {
193*14675a02SAndroid Build Coastguard Worker     FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
194*14675a02SAndroid Build Coastguard Worker                          fcp::ReadFileToString(checkpoint_file));
195*14675a02SAndroid Build Coastguard Worker     computation_results[std::string(kTensorflowCheckpointAggregand)] =
196*14675a02SAndroid Build Coastguard Worker         std::move(tf_checkpoint);
197*14675a02SAndroid Build Coastguard Worker   }
198*14675a02SAndroid Build Coastguard Worker   return computation_results;
199*14675a02SAndroid Build Coastguard Worker }
#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
// Builds the (tensor name, tensor) input list for running an eligibility eval
// plan with the TensorFlow Mobile engine.
//
// If the IO router names an input-filepath tensor, the list holds exactly one
// entry: a scalar DT_STRING tensor containing `checkpoint_input_filename`.
// Otherwise the returned list is empty.
std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForEligibilityEvalPlan(
    const FederatedComputeEligibilityIORouter& io_router,
    const std::string& checkpoint_input_filename) {
  using NamedTensors = std::vector<std::pair<std::string, tensorflow::Tensor>>;
  auto named_tensors = std::make_unique<NamedTensors>();
  const std::string& filepath_tensor_name =
      io_router.input_filepath_tensor_name();
  if (filepath_tensor_name.empty()) {
    return named_tensors;
  }
  tensorflow::Tensor filepath_tensor(tensorflow::DT_STRING, {});
  filepath_tensor.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
  named_tensors->emplace_back(filepath_tensor_name, filepath_tensor);
  return named_tensors;
}
#endif
ConstructTfLiteInputsForEligibilityEvalPlan(const FederatedComputeEligibilityIORouter & io_router,const std::string & checkpoint_input_filename)215*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TfLiteInputs> ConstructTfLiteInputsForEligibilityEvalPlan(
216*14675a02SAndroid Build Coastguard Worker     const FederatedComputeEligibilityIORouter& io_router,
217*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename) {
218*14675a02SAndroid Build Coastguard Worker   auto inputs = std::make_unique<TfLiteInputs>();
219*14675a02SAndroid Build Coastguard Worker   if (!io_router.input_filepath_tensor_name().empty()) {
220*14675a02SAndroid Build Coastguard Worker     (*inputs)[io_router.input_filepath_tensor_name()] =
221*14675a02SAndroid Build Coastguard Worker         checkpoint_input_filename;
222*14675a02SAndroid Build Coastguard Worker   }
223*14675a02SAndroid Build Coastguard Worker   return inputs;
224*14675a02SAndroid Build Coastguard Worker }
225*14675a02SAndroid Build Coastguard Worker // Returns the cumulative network stats (those incurred up until this point in
226*14675a02SAndroid Build Coastguard Worker // time).
227*14675a02SAndroid Build Coastguard Worker //
228*14675a02SAndroid Build Coastguard Worker // The `FederatedSelectManager` object may be null, if it is know that there
229*14675a02SAndroid Build Coastguard Worker // has been no network usage from it yet.
GetCumulativeNetworkStats(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager)230*14675a02SAndroid Build Coastguard Worker NetworkStats GetCumulativeNetworkStats(
231*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol,
232*14675a02SAndroid Build Coastguard Worker     FederatedSelectManager* fedselect_manager) {
233*14675a02SAndroid Build Coastguard Worker   NetworkStats result = federated_protocol->GetNetworkStats();
234*14675a02SAndroid Build Coastguard Worker   if (fedselect_manager != nullptr) {
235*14675a02SAndroid Build Coastguard Worker     result = result + fedselect_manager->GetNetworkStats();
236*14675a02SAndroid Build Coastguard Worker   }
237*14675a02SAndroid Build Coastguard Worker   return result;
238*14675a02SAndroid Build Coastguard Worker }
239*14675a02SAndroid Build Coastguard Worker // Returns the newly incurred network stats since the previous snapshot of
240*14675a02SAndroid Build Coastguard Worker // stats (the `reference_point` argument).
GetNetworkStatsSince(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const NetworkStats & reference_point)241*14675a02SAndroid Build Coastguard Worker NetworkStats GetNetworkStatsSince(FederatedProtocol* federated_protocol,
242*14675a02SAndroid Build Coastguard Worker                                   FederatedSelectManager* fedselect_manager,
243*14675a02SAndroid Build Coastguard Worker                                   const NetworkStats& reference_point) {
244*14675a02SAndroid Build Coastguard Worker   return GetCumulativeNetworkStats(federated_protocol, fedselect_manager) -
245*14675a02SAndroid Build Coastguard Worker          reference_point;
246*14675a02SAndroid Build Coastguard Worker }
247*14675a02SAndroid Build Coastguard Worker // Updates the fields of `FLRunnerResult` that should always be updated after
248*14675a02SAndroid Build Coastguard Worker // each interaction with the `FederatedProtocol` or `FederatedSelectManager`
249*14675a02SAndroid Build Coastguard Worker // objects.
250*14675a02SAndroid Build Coastguard Worker //
251*14675a02SAndroid Build Coastguard Worker // The `FederatedSelectManager` object may be null, if it is know that there
252*14675a02SAndroid Build Coastguard Worker // has been no network usage from it yet.
UpdateRetryWindowAndNetworkStats(FederatedProtocol & federated_protocol,FederatedSelectManager * fedselect_manager,PhaseLogger & phase_logger,FLRunnerResult & fl_runner_result)253*14675a02SAndroid Build Coastguard Worker void UpdateRetryWindowAndNetworkStats(FederatedProtocol& federated_protocol,
254*14675a02SAndroid Build Coastguard Worker                                       FederatedSelectManager* fedselect_manager,
255*14675a02SAndroid Build Coastguard Worker                                       PhaseLogger& phase_logger,
256*14675a02SAndroid Build Coastguard Worker                                       FLRunnerResult& fl_runner_result) {
257*14675a02SAndroid Build Coastguard Worker   // Update the result's retry window to the most recent one.
258*14675a02SAndroid Build Coastguard Worker   auto retry_window = federated_protocol.GetLatestRetryWindow();
259*14675a02SAndroid Build Coastguard Worker   RetryInfo retry_info;
260*14675a02SAndroid Build Coastguard Worker   *retry_info.mutable_retry_token() = retry_window.retry_token();
261*14675a02SAndroid Build Coastguard Worker   *retry_info.mutable_minimum_delay() = retry_window.delay_min();
262*14675a02SAndroid Build Coastguard Worker   *fl_runner_result.mutable_retry_info() = retry_info;
263*14675a02SAndroid Build Coastguard Worker   phase_logger.UpdateRetryWindowAndNetworkStats(
264*14675a02SAndroid Build Coastguard Worker       retry_window,
265*14675a02SAndroid Build Coastguard Worker       GetCumulativeNetworkStats(&federated_protocol, fedselect_manager));
266*14675a02SAndroid Build Coastguard Worker }
267*14675a02SAndroid Build Coastguard Worker // Creates an ExampleIteratorFactory that routes queries to the
268*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment::CreateExampleIterator() method.
269*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)270*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(
271*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
272*14675a02SAndroid Build Coastguard Worker   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
273*14675a02SAndroid Build Coastguard Worker       /*can_handle_func=*/
274*14675a02SAndroid Build Coastguard Worker       [](const google::internal::federated::plan::ExampleSelector&) {
275*14675a02SAndroid Build Coastguard Worker         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
276*14675a02SAndroid Build Coastguard Worker         // be the catch-all factory that is able to handle all queries
277*14675a02SAndroid Build Coastguard Worker         // that no other ExampleIteratorFactory is able to handle.
278*14675a02SAndroid Build Coastguard Worker         return true;
279*14675a02SAndroid Build Coastguard Worker       },
280*14675a02SAndroid Build Coastguard Worker       /*create_iterator_func=*/
281*14675a02SAndroid Build Coastguard Worker       [task_env, selector_context](
282*14675a02SAndroid Build Coastguard Worker           const google::internal::federated::plan::ExampleSelector&
283*14675a02SAndroid Build Coastguard Worker               example_selector) {
284*14675a02SAndroid Build Coastguard Worker         return task_env->CreateExampleIterator(example_selector,
285*14675a02SAndroid Build Coastguard Worker                                                selector_context);
286*14675a02SAndroid Build Coastguard Worker       },
287*14675a02SAndroid Build Coastguard Worker       /*should_collect_stats=*/true);
288*14675a02SAndroid Build Coastguard Worker }
RunEligibilityEvalPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)289*14675a02SAndroid Build Coastguard Worker engine::PlanResult RunEligibilityEvalPlanWithTensorflowSpec(
290*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
291*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, LogManager* log_manager,
292*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
293*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
294*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
295*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
296*14675a02SAndroid Build Coastguard Worker     const absl::Time run_plan_start_time, const absl::Time reference_time) {
297*14675a02SAndroid Build Coastguard Worker   // Check that this is a TensorflowSpec-based plan for federated eligibility
298*14675a02SAndroid Build Coastguard Worker   // computation.
299*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_tensorflow_spec() ||
300*14675a02SAndroid Build Coastguard Worker       !client_plan.phase().has_federated_compute_eligibility()) {
301*14675a02SAndroid Build Coastguard Worker     return engine::PlanResult(
302*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
303*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid eligibility eval plan"));
304*14675a02SAndroid Build Coastguard Worker   }
305*14675a02SAndroid Build Coastguard Worker   const FederatedComputeEligibilityIORouter& io_router =
306*14675a02SAndroid Build Coastguard Worker       client_plan.phase().federated_compute_eligibility();
307*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names = {
308*14675a02SAndroid Build Coastguard Worker       io_router.task_eligibility_info_tensor_name()};
309*14675a02SAndroid Build Coastguard Worker   if (!client_plan.tflite_graph().empty()) {
310*14675a02SAndroid Build Coastguard Worker     log_manager->LogDiag(
311*14675a02SAndroid Build Coastguard Worker         ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
312*14675a02SAndroid Build Coastguard Worker   }
313*14675a02SAndroid Build Coastguard Worker   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
314*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteInputs> tflite_inputs =
315*14675a02SAndroid Build Coastguard Worker         ConstructTfLiteInputsForEligibilityEvalPlan(io_router,
316*14675a02SAndroid Build Coastguard Worker                                                     checkpoint_input_filename);
317*14675a02SAndroid Build Coastguard Worker     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
318*14675a02SAndroid Build Coastguard Worker                                          should_abort, log_manager,
319*14675a02SAndroid Build Coastguard Worker                                          opstats_logger, flags, &timing_config);
320*14675a02SAndroid Build Coastguard Worker     return plan_engine.RunPlan(client_plan.phase().tensorflow_spec(),
321*14675a02SAndroid Build Coastguard Worker                                client_plan.tflite_graph(),
322*14675a02SAndroid Build Coastguard Worker                                std::move(tflite_inputs), output_names);
323*14675a02SAndroid Build Coastguard Worker   }
324*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
325*14675a02SAndroid Build Coastguard Worker   // Construct input tensors and output tensor names based on the values in
326*14675a02SAndroid Build Coastguard Worker   // the FederatedComputeEligibilityIORouter message.
327*14675a02SAndroid Build Coastguard Worker   auto inputs = ConstructInputsForEligibilityEvalPlan(
328*14675a02SAndroid Build Coastguard Worker       io_router, checkpoint_input_filename);
329*14675a02SAndroid Build Coastguard Worker   // Run plan and get a set of output tensors back.
330*14675a02SAndroid Build Coastguard Worker   engine::SimplePlanEngine plan_engine(
331*14675a02SAndroid Build Coastguard Worker       example_iterator_factories, should_abort, log_manager, opstats_logger,
332*14675a02SAndroid Build Coastguard Worker       &timing_config, flags->support_constant_tf_inputs());
333*14675a02SAndroid Build Coastguard Worker   return plan_engine.RunPlan(
334*14675a02SAndroid Build Coastguard Worker       client_plan.phase().tensorflow_spec(), client_plan.graph(),
335*14675a02SAndroid Build Coastguard Worker       client_plan.tensorflow_config_proto(), std::move(inputs), output_names);
336*14675a02SAndroid Build Coastguard Worker #else
337*14675a02SAndroid Build Coastguard Worker   return engine::PlanResult(
338*14675a02SAndroid Build Coastguard Worker       engine::PlanOutcome::kTensorflowError,
339*14675a02SAndroid Build Coastguard Worker       absl::InternalError("No eligibility eval plan engine enabled"));
340*14675a02SAndroid Build Coastguard Worker #endif
341*14675a02SAndroid Build Coastguard Worker }
342*14675a02SAndroid Build Coastguard Worker // Validates the output tensors that resulted from executing the plan, and
343*14675a02SAndroid Build Coastguard Worker // then parses the output into a TaskEligibilityInfo proto. Returns an error
344*14675a02SAndroid Build Coastguard Worker // if validation or parsing failed.
ParseEligibilityEvalPlanOutput(const std::vector<tensorflow::Tensor> & output_tensors)345*14675a02SAndroid Build Coastguard Worker absl::StatusOr<TaskEligibilityInfo> ParseEligibilityEvalPlanOutput(
346*14675a02SAndroid Build Coastguard Worker     const std::vector<tensorflow::Tensor>& output_tensors) {
347*14675a02SAndroid Build Coastguard Worker   auto output_size = output_tensors.size();
348*14675a02SAndroid Build Coastguard Worker   if (output_size != 1) {
349*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
350*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Unexpected number of output tensors: ", output_size));
351*14675a02SAndroid Build Coastguard Worker   }
352*14675a02SAndroid Build Coastguard Worker   auto output_elements = output_tensors[0].NumElements();
353*14675a02SAndroid Build Coastguard Worker   if (output_elements != 1) {
354*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(absl::StrCat(
355*14675a02SAndroid Build Coastguard Worker         "Unexpected number of output tensor elements: ", output_elements));
356*14675a02SAndroid Build Coastguard Worker   }
357*14675a02SAndroid Build Coastguard Worker   tensorflow::DataType output_type = output_tensors[0].dtype();
358*14675a02SAndroid Build Coastguard Worker   if (output_type != tensorflow::DT_STRING) {
359*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
360*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Unexpected output tensor type: ", output_type));
361*14675a02SAndroid Build Coastguard Worker   }
362*14675a02SAndroid Build Coastguard Worker   // Extract the serialized TaskEligibilityInfo proto from the tensor and
363*14675a02SAndroid Build Coastguard Worker   // parse it.
364*14675a02SAndroid Build Coastguard Worker   // First, convert the output Tensor into a Scalar (= a TensorMap with 1
365*14675a02SAndroid Build Coastguard Worker   // element), then use its operator() to access the actual data.
366*14675a02SAndroid Build Coastguard Worker   const tensorflow::tstring& serialized_output =
367*14675a02SAndroid Build Coastguard Worker       output_tensors[0].scalar<const tensorflow::tstring>()();
368*14675a02SAndroid Build Coastguard Worker   TaskEligibilityInfo parsed_output;
369*14675a02SAndroid Build Coastguard Worker   if (!parsed_output.ParseFromString(serialized_output)) {
370*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Could not parse output proto");
371*14675a02SAndroid Build Coastguard Worker   }
372*14675a02SAndroid Build Coastguard Worker   return parsed_output;
373*14675a02SAndroid Build Coastguard Worker }
#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
// Builds the (tensor name, tensor) feed list for a TensorFlow Mobile plan.
// For each checkpoint filepath tensor name (input and output) that the
// FederatedComputeIORouter sets, a scalar DT_STRING tensor holding the
// corresponding filename is appended. Empty tensor names are skipped.
std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForTensorflowSpecPlan(
    const FederatedComputeIORouter& io_router,
    const std::string& checkpoint_input_filename,
    const std::string& checkpoint_output_filename) {
  using NamedTensor = std::pair<std::string, tensorflow::Tensor>;
  auto feeds = std::make_unique<std::vector<NamedTensor>>();

  // Appends a scalar string tensor named `tensor_name` that holds `value`.
  auto add_filename_feed = [&feeds](const std::string& tensor_name,
                                    const std::string& value) {
    tensorflow::Tensor filename_tensor(tensorflow::DT_STRING, {});
    filename_tensor.scalar<tensorflow::tstring>()() = value;
    feeds->push_back({tensor_name, filename_tensor});
  };

  if (!io_router.input_filepath_tensor_name().empty()) {
    add_filename_feed(io_router.input_filepath_tensor_name(),
                      checkpoint_input_filename);
  }
  if (!io_router.output_filepath_tensor_name().empty()) {
    add_filename_feed(io_router.output_filepath_tensor_name(),
                      checkpoint_output_filename);
  }
  return feeds;
}
#endif
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)397*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
398*14675a02SAndroid Build Coastguard Worker     const FederatedComputeIORouter& io_router,
399*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
400*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename) {
401*14675a02SAndroid Build Coastguard Worker   auto inputs = std::make_unique<TfLiteInputs>();
402*14675a02SAndroid Build Coastguard Worker   if (!io_router.input_filepath_tensor_name().empty()) {
403*14675a02SAndroid Build Coastguard Worker     (*inputs)[io_router.input_filepath_tensor_name()] =
404*14675a02SAndroid Build Coastguard Worker         checkpoint_input_filename;
405*14675a02SAndroid Build Coastguard Worker   }
406*14675a02SAndroid Build Coastguard Worker   if (!io_router.output_filepath_tensor_name().empty()) {
407*14675a02SAndroid Build Coastguard Worker     (*inputs)[io_router.output_filepath_tensor_name()] =
408*14675a02SAndroid Build Coastguard Worker         checkpoint_output_filename;
409*14675a02SAndroid Build Coastguard Worker   }
410*14675a02SAndroid Build Coastguard Worker   return inputs;
411*14675a02SAndroid Build Coastguard Worker }
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)412*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
413*14675a02SAndroid Build Coastguard Worker     const TensorflowSpec& tensorflow_spec,
414*14675a02SAndroid Build Coastguard Worker     const FederatedComputeIORouter& io_router) {
415*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names;
416*14675a02SAndroid Build Coastguard Worker   // The order of output tensor names should match the order in
417*14675a02SAndroid Build Coastguard Worker   // TensorflowSpec.
418*14675a02SAndroid Build Coastguard Worker   for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
419*14675a02SAndroid Build Coastguard Worker     std::string tensor_name = output_tensor_spec.name();
420*14675a02SAndroid Build Coastguard Worker     if (!io_router.aggregations().contains(tensor_name) ||
421*14675a02SAndroid Build Coastguard Worker         !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
422*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
423*14675a02SAndroid Build Coastguard Worker           "Output tensor is missing in AggregationConfig, or has unsupported "
424*14675a02SAndroid Build Coastguard Worker           "aggregation type.");
425*14675a02SAndroid Build Coastguard Worker     }
426*14675a02SAndroid Build Coastguard Worker     output_names.push_back(tensor_name);
427*14675a02SAndroid Build Coastguard Worker   }
428*14675a02SAndroid Build Coastguard Worker   return output_names;
429*14675a02SAndroid Build Coastguard Worker }
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)430*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
431*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
432*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, LogManager* log_manager,
433*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
434*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
435*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
436*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename,
437*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
438*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_tensorflow_spec()) {
439*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
440*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
441*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
442*14675a02SAndroid Build Coastguard Worker   }
443*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_federated_compute()) {
444*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
445*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
446*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
447*14675a02SAndroid Build Coastguard Worker   }
448*14675a02SAndroid Build Coastguard Worker   // Get the output tensor names.
449*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::vector<std::string>> output_names;
450*14675a02SAndroid Build Coastguard Worker   output_names = ConstructOutputsWithDeterministicOrder(
451*14675a02SAndroid Build Coastguard Worker       client_plan.phase().tensorflow_spec(),
452*14675a02SAndroid Build Coastguard Worker       client_plan.phase().federated_compute());
453*14675a02SAndroid Build Coastguard Worker   if (!output_names.ok()) {
454*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
455*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument, output_names.status()));
456*14675a02SAndroid Build Coastguard Worker   }
457*14675a02SAndroid Build Coastguard Worker   // Run plan and get a set of output tensors back.
458*14675a02SAndroid Build Coastguard Worker   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
459*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteInputs> tflite_inputs =
460*14675a02SAndroid Build Coastguard Worker         ConstructTFLiteInputsForTensorflowSpecPlan(
461*14675a02SAndroid Build Coastguard Worker             client_plan.phase().federated_compute(), checkpoint_input_filename,
462*14675a02SAndroid Build Coastguard Worker             checkpoint_output_filename);
463*14675a02SAndroid Build Coastguard Worker     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
464*14675a02SAndroid Build Coastguard Worker                                          should_abort, log_manager,
465*14675a02SAndroid Build Coastguard Worker                                          opstats_logger, flags, &timing_config);
466*14675a02SAndroid Build Coastguard Worker     engine::PlanResult plan_result = plan_engine.RunPlan(
467*14675a02SAndroid Build Coastguard Worker         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
468*14675a02SAndroid Build Coastguard Worker         std::move(tflite_inputs), *output_names);
469*14675a02SAndroid Build Coastguard Worker     PlanResultAndCheckpointFile result(std::move(plan_result));
470*14675a02SAndroid Build Coastguard Worker     result.checkpoint_file = checkpoint_output_filename;
471*14675a02SAndroid Build Coastguard Worker     return result;
472*14675a02SAndroid Build Coastguard Worker   }
473*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
474*14675a02SAndroid Build Coastguard Worker   // Construct input tensors based on the values in the
475*14675a02SAndroid Build Coastguard Worker   // FederatedComputeIORouter message and create a temporary file for the
476*14675a02SAndroid Build Coastguard Worker   // output checkpoint if needed.
477*14675a02SAndroid Build Coastguard Worker   auto inputs = ConstructInputsForTensorflowSpecPlan(
478*14675a02SAndroid Build Coastguard Worker       client_plan.phase().federated_compute(), checkpoint_input_filename,
479*14675a02SAndroid Build Coastguard Worker       checkpoint_output_filename);
480*14675a02SAndroid Build Coastguard Worker   engine::SimplePlanEngine plan_engine(
481*14675a02SAndroid Build Coastguard Worker       example_iterator_factories, should_abort, log_manager, opstats_logger,
482*14675a02SAndroid Build Coastguard Worker       &timing_config, flags->support_constant_tf_inputs());
483*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result = plan_engine.RunPlan(
484*14675a02SAndroid Build Coastguard Worker       client_plan.phase().tensorflow_spec(), client_plan.graph(),
485*14675a02SAndroid Build Coastguard Worker       client_plan.tensorflow_config_proto(), std::move(inputs), *output_names);
486*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile result(std::move(plan_result));
487*14675a02SAndroid Build Coastguard Worker   result.checkpoint_file = checkpoint_output_filename;
488*14675a02SAndroid Build Coastguard Worker   return result;
489*14675a02SAndroid Build Coastguard Worker #else
490*14675a02SAndroid Build Coastguard Worker   return PlanResultAndCheckpointFile(
491*14675a02SAndroid Build Coastguard Worker       engine::PlanResult(engine::PlanOutcome::kTensorflowError,
492*14675a02SAndroid Build Coastguard Worker                          absl::InternalError("No plan engine enabled")));
493*14675a02SAndroid Build Coastguard Worker #endif
494*14675a02SAndroid Build Coastguard Worker }
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)495*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
496*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
497*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
498*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
499*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename) {
500*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_example_query_spec()) {
501*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
502*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
503*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
504*14675a02SAndroid Build Coastguard Worker   }
505*14675a02SAndroid Build Coastguard Worker   if (!flags->enable_example_query_plan_engine()) {
506*14675a02SAndroid Build Coastguard Worker     // Example query plan received while the flag is off.
507*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
508*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
509*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError(
510*14675a02SAndroid Build Coastguard Worker             "Example query plan received while the flag is off")));
511*14675a02SAndroid Build Coastguard Worker   }
512*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_federated_example_query()) {
513*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
514*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
515*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
516*14675a02SAndroid Build Coastguard Worker   }
517*14675a02SAndroid Build Coastguard Worker   for (const auto& example_query :
518*14675a02SAndroid Build Coastguard Worker        client_plan.phase().example_query_spec().example_queries()) {
519*14675a02SAndroid Build Coastguard Worker     for (auto const& [vector_name, spec] :
520*14675a02SAndroid Build Coastguard Worker          example_query.output_vector_specs()) {
521*14675a02SAndroid Build Coastguard Worker       const auto& aggregations =
522*14675a02SAndroid Build Coastguard Worker           client_plan.phase().federated_example_query().aggregations();
523*14675a02SAndroid Build Coastguard Worker       if ((aggregations.find(vector_name) == aggregations.end()) ||
524*14675a02SAndroid Build Coastguard Worker           !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
525*14675a02SAndroid Build Coastguard Worker         return PlanResultAndCheckpointFile(engine::PlanResult(
526*14675a02SAndroid Build Coastguard Worker             engine::PlanOutcome::kInvalidArgument,
527*14675a02SAndroid Build Coastguard Worker             absl::InvalidArgumentError("Output vector is missing in "
528*14675a02SAndroid Build Coastguard Worker                                        "AggregationConfig, or has unsupported "
529*14675a02SAndroid Build Coastguard Worker                                        "aggregation type.")));
530*14675a02SAndroid Build Coastguard Worker       }
531*14675a02SAndroid Build Coastguard Worker     }
532*14675a02SAndroid Build Coastguard Worker   }
533*14675a02SAndroid Build Coastguard Worker   engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
534*14675a02SAndroid Build Coastguard Worker                                              opstats_logger);
535*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result = plan_engine.RunPlan(
536*14675a02SAndroid Build Coastguard Worker       client_plan.phase().example_query_spec(), checkpoint_output_filename);
537*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile result(std::move(plan_result));
538*14675a02SAndroid Build Coastguard Worker   result.checkpoint_file = checkpoint_output_filename;
539*14675a02SAndroid Build Coastguard Worker   return result;
540*14675a02SAndroid Build Coastguard Worker }
LogEligibilityEvalComputationOutcome(PhaseLogger & phase_logger,engine::PlanResult plan_result,const absl::Status & eligibility_info_parsing_status,absl::Time run_plan_start_time,absl::Time reference_time)541*14675a02SAndroid Build Coastguard Worker void LogEligibilityEvalComputationOutcome(
542*14675a02SAndroid Build Coastguard Worker     PhaseLogger& phase_logger, engine::PlanResult plan_result,
543*14675a02SAndroid Build Coastguard Worker     const absl::Status& eligibility_info_parsing_status,
544*14675a02SAndroid Build Coastguard Worker     absl::Time run_plan_start_time, absl::Time reference_time) {
545*14675a02SAndroid Build Coastguard Worker   switch (plan_result.outcome) {
546*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kSuccess: {
547*14675a02SAndroid Build Coastguard Worker       if (eligibility_info_parsing_status.ok()) {
548*14675a02SAndroid Build Coastguard Worker         phase_logger.LogEligibilityEvalComputationCompleted(
549*14675a02SAndroid Build Coastguard Worker             plan_result.example_stats, run_plan_start_time, reference_time);
550*14675a02SAndroid Build Coastguard Worker       } else {
551*14675a02SAndroid Build Coastguard Worker         phase_logger.LogEligibilityEvalComputationTensorflowError(
552*14675a02SAndroid Build Coastguard Worker             eligibility_info_parsing_status, plan_result.example_stats,
553*14675a02SAndroid Build Coastguard Worker             run_plan_start_time, reference_time);
554*14675a02SAndroid Build Coastguard Worker         FCP_LOG(ERROR) << eligibility_info_parsing_status.message();
555*14675a02SAndroid Build Coastguard Worker       }
556*14675a02SAndroid Build Coastguard Worker       break;
557*14675a02SAndroid Build Coastguard Worker     }
558*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInterrupted:
559*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalComputationInterrupted(
560*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
561*14675a02SAndroid Build Coastguard Worker           run_plan_start_time, reference_time);
562*14675a02SAndroid Build Coastguard Worker       break;
563*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInvalidArgument:
564*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalComputationInvalidArgument(
565*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
566*14675a02SAndroid Build Coastguard Worker           run_plan_start_time);
567*14675a02SAndroid Build Coastguard Worker       break;
568*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kTensorflowError:
569*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalComputationTensorflowError(
570*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
571*14675a02SAndroid Build Coastguard Worker           run_plan_start_time, reference_time);
572*14675a02SAndroid Build Coastguard Worker       break;
573*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kExampleIteratorError:
574*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalComputationExampleIteratorError(
575*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
576*14675a02SAndroid Build Coastguard Worker           run_plan_start_time);
577*14675a02SAndroid Build Coastguard Worker       break;
578*14675a02SAndroid Build Coastguard Worker   }
579*14675a02SAndroid Build Coastguard Worker }
LogComputationOutcome(const engine::PlanResult & plan_result,absl::Status computation_results_parsing_status,PhaseLogger & phase_logger,const NetworkStats & network_stats,absl::Time run_plan_start_time,absl::Time reference_time)580*14675a02SAndroid Build Coastguard Worker void LogComputationOutcome(const engine::PlanResult& plan_result,
581*14675a02SAndroid Build Coastguard Worker                            absl::Status computation_results_parsing_status,
582*14675a02SAndroid Build Coastguard Worker                            PhaseLogger& phase_logger,
583*14675a02SAndroid Build Coastguard Worker                            const NetworkStats& network_stats,
584*14675a02SAndroid Build Coastguard Worker                            absl::Time run_plan_start_time,
585*14675a02SAndroid Build Coastguard Worker                            absl::Time reference_time) {
586*14675a02SAndroid Build Coastguard Worker   switch (plan_result.outcome) {
587*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kSuccess: {
588*14675a02SAndroid Build Coastguard Worker       if (computation_results_parsing_status.ok()) {
589*14675a02SAndroid Build Coastguard Worker         phase_logger.LogComputationCompleted(plan_result.example_stats,
590*14675a02SAndroid Build Coastguard Worker                                              network_stats, run_plan_start_time,
591*14675a02SAndroid Build Coastguard Worker                                              reference_time);
592*14675a02SAndroid Build Coastguard Worker       } else {
593*14675a02SAndroid Build Coastguard Worker         phase_logger.LogComputationTensorflowError(
594*14675a02SAndroid Build Coastguard Worker             computation_results_parsing_status, plan_result.example_stats,
595*14675a02SAndroid Build Coastguard Worker             network_stats, run_plan_start_time, reference_time);
596*14675a02SAndroid Build Coastguard Worker       }
597*14675a02SAndroid Build Coastguard Worker       break;
598*14675a02SAndroid Build Coastguard Worker     }
599*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInterrupted:
600*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationInterrupted(
601*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats, network_stats,
602*14675a02SAndroid Build Coastguard Worker           run_plan_start_time, reference_time);
603*14675a02SAndroid Build Coastguard Worker       break;
604*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInvalidArgument:
605*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationInvalidArgument(
606*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats, network_stats,
607*14675a02SAndroid Build Coastguard Worker           run_plan_start_time);
608*14675a02SAndroid Build Coastguard Worker       break;
609*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kTensorflowError:
610*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationTensorflowError(
611*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats, network_stats,
612*14675a02SAndroid Build Coastguard Worker           run_plan_start_time, reference_time);
613*14675a02SAndroid Build Coastguard Worker       break;
614*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kExampleIteratorError:
615*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationExampleIteratorError(
616*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats, network_stats,
617*14675a02SAndroid Build Coastguard Worker           run_plan_start_time);
618*14675a02SAndroid Build Coastguard Worker       break;
619*14675a02SAndroid Build Coastguard Worker   }
620*14675a02SAndroid Build Coastguard Worker }
LogResultUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_result_upload,absl::Time reference_time)621*14675a02SAndroid Build Coastguard Worker void LogResultUploadStatus(PhaseLogger& phase_logger, absl::Status result,
622*14675a02SAndroid Build Coastguard Worker                            const NetworkStats& network_stats,
623*14675a02SAndroid Build Coastguard Worker                            absl::Time time_before_result_upload,
624*14675a02SAndroid Build Coastguard Worker                            absl::Time reference_time) {
625*14675a02SAndroid Build Coastguard Worker   if (result.ok()) {
626*14675a02SAndroid Build Coastguard Worker     phase_logger.LogResultUploadCompleted(
627*14675a02SAndroid Build Coastguard Worker         network_stats, time_before_result_upload, reference_time);
628*14675a02SAndroid Build Coastguard Worker   } else {
629*14675a02SAndroid Build Coastguard Worker     auto message =
630*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Error reporting results: code: ", result.code(),
631*14675a02SAndroid Build Coastguard Worker                      ", message: ", result.message());
632*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
633*14675a02SAndroid Build Coastguard Worker     if (result.code() == absl::StatusCode::kAborted) {
634*14675a02SAndroid Build Coastguard Worker       phase_logger.LogResultUploadServerAborted(
635*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_result_upload, reference_time);
636*14675a02SAndroid Build Coastguard Worker     } else if (result.code() == absl::StatusCode::kCancelled) {
637*14675a02SAndroid Build Coastguard Worker       phase_logger.LogResultUploadClientInterrupted(
638*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_result_upload, reference_time);
639*14675a02SAndroid Build Coastguard Worker     } else {
640*14675a02SAndroid Build Coastguard Worker       phase_logger.LogResultUploadIOError(
641*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_result_upload, reference_time);
642*14675a02SAndroid Build Coastguard Worker     }
643*14675a02SAndroid Build Coastguard Worker   }
644*14675a02SAndroid Build Coastguard Worker }
LogFailureUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_failure_upload,absl::Time reference_time)645*14675a02SAndroid Build Coastguard Worker void LogFailureUploadStatus(PhaseLogger& phase_logger, absl::Status result,
646*14675a02SAndroid Build Coastguard Worker                             const NetworkStats& network_stats,
647*14675a02SAndroid Build Coastguard Worker                             absl::Time time_before_failure_upload,
648*14675a02SAndroid Build Coastguard Worker                             absl::Time reference_time) {
649*14675a02SAndroid Build Coastguard Worker   if (result.ok()) {
650*14675a02SAndroid Build Coastguard Worker     phase_logger.LogFailureUploadCompleted(
651*14675a02SAndroid Build Coastguard Worker         network_stats, time_before_failure_upload, reference_time);
652*14675a02SAndroid Build Coastguard Worker   } else {
653*14675a02SAndroid Build Coastguard Worker     auto message = absl::StrCat("Error reporting computation failure: code: ",
654*14675a02SAndroid Build Coastguard Worker                                 result.code(), ", message: ", result.message());
655*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
656*14675a02SAndroid Build Coastguard Worker     if (result.code() == absl::StatusCode::kAborted) {
657*14675a02SAndroid Build Coastguard Worker       phase_logger.LogFailureUploadServerAborted(
658*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_failure_upload, reference_time);
659*14675a02SAndroid Build Coastguard Worker     } else if (result.code() == absl::StatusCode::kCancelled) {
660*14675a02SAndroid Build Coastguard Worker       phase_logger.LogFailureUploadClientInterrupted(
661*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_failure_upload, reference_time);
662*14675a02SAndroid Build Coastguard Worker     } else {
663*14675a02SAndroid Build Coastguard Worker       phase_logger.LogFailureUploadIOError(
664*14675a02SAndroid Build Coastguard Worker           result, network_stats, time_before_failure_upload, reference_time);
665*14675a02SAndroid Build Coastguard Worker     }
666*14675a02SAndroid Build Coastguard Worker   }
667*14675a02SAndroid Build Coastguard Worker }
ReportPlanResult(FederatedProtocol * federated_protocol,PhaseLogger & phase_logger,absl::StatusOr<ComputationResults> computation_results,absl::Time run_plan_start_time,absl::Time reference_time)668*14675a02SAndroid Build Coastguard Worker absl::Status ReportPlanResult(
669*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol, PhaseLogger& phase_logger,
670*14675a02SAndroid Build Coastguard Worker     absl::StatusOr<ComputationResults> computation_results,
671*14675a02SAndroid Build Coastguard Worker     absl::Time run_plan_start_time, absl::Time reference_time) {
672*14675a02SAndroid Build Coastguard Worker   const absl::Time before_report_time = absl::Now();
673*14675a02SAndroid Build Coastguard Worker   // Note that the FederatedSelectManager shouldn't be active anymore during
674*14675a02SAndroid Build Coastguard Worker   // the reporting of results, so we don't bother passing it to
675*14675a02SAndroid Build Coastguard Worker   // GetNetworkStatsSince.
676*14675a02SAndroid Build Coastguard Worker   //
677*14675a02SAndroid Build Coastguard Worker   // We must return only stats that cover the report phase for the log events
678*14675a02SAndroid Build Coastguard Worker   // below.
679*14675a02SAndroid Build Coastguard Worker   const NetworkStats before_report_stats =
680*14675a02SAndroid Build Coastguard Worker       GetCumulativeNetworkStats(federated_protocol,
681*14675a02SAndroid Build Coastguard Worker                                 /*fedselect_manager=*/nullptr);
682*14675a02SAndroid Build Coastguard Worker   absl::Status result = absl::InternalError("");
683*14675a02SAndroid Build Coastguard Worker   if (computation_results.ok()) {
684*14675a02SAndroid Build Coastguard Worker     FCP_RETURN_IF_ERROR(phase_logger.LogResultUploadStarted());
685*14675a02SAndroid Build Coastguard Worker     result = federated_protocol->ReportCompleted(
686*14675a02SAndroid Build Coastguard Worker         std::move(*computation_results),
687*14675a02SAndroid Build Coastguard Worker         /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
688*14675a02SAndroid Build Coastguard Worker     LogResultUploadStatus(phase_logger, result,
689*14675a02SAndroid Build Coastguard Worker                           GetNetworkStatsSince(federated_protocol,
690*14675a02SAndroid Build Coastguard Worker                                                /*fedselect_manager=*/nullptr,
691*14675a02SAndroid Build Coastguard Worker                                                before_report_stats),
692*14675a02SAndroid Build Coastguard Worker                           before_report_time, reference_time);
693*14675a02SAndroid Build Coastguard Worker   } else {
694*14675a02SAndroid Build Coastguard Worker     FCP_RETURN_IF_ERROR(phase_logger.LogFailureUploadStarted());
695*14675a02SAndroid Build Coastguard Worker     result = federated_protocol->ReportNotCompleted(
696*14675a02SAndroid Build Coastguard Worker         engine::PhaseOutcome::ERROR,
697*14675a02SAndroid Build Coastguard Worker         /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
698*14675a02SAndroid Build Coastguard Worker     LogFailureUploadStatus(phase_logger, result,
699*14675a02SAndroid Build Coastguard Worker                            GetNetworkStatsSince(federated_protocol,
700*14675a02SAndroid Build Coastguard Worker                                                 /*fedselect_manager=*/nullptr,
701*14675a02SAndroid Build Coastguard Worker                                                 before_report_stats),
702*14675a02SAndroid Build Coastguard Worker                            before_report_time, reference_time);
703*14675a02SAndroid Build Coastguard Worker   }
704*14675a02SAndroid Build Coastguard Worker   return result;
705*14675a02SAndroid Build Coastguard Worker }
706*14675a02SAndroid Build Coastguard Worker // Writes the given data to the stream, and returns true if successful and
707*14675a02SAndroid Build Coastguard Worker // false if not.
WriteStringOrCordToFstream(std::fstream & stream,const std::variant<std::string,absl::Cord> & data)708*14675a02SAndroid Build Coastguard Worker bool WriteStringOrCordToFstream(
709*14675a02SAndroid Build Coastguard Worker     std::fstream& stream, const std::variant<std::string, absl::Cord>& data) {
710*14675a02SAndroid Build Coastguard Worker   if (stream.fail()) {
711*14675a02SAndroid Build Coastguard Worker     return false;
712*14675a02SAndroid Build Coastguard Worker   }
713*14675a02SAndroid Build Coastguard Worker   if (std::holds_alternative<std::string>(data)) {
714*14675a02SAndroid Build Coastguard Worker     return (stream << std::get<std::string>(data)).good();
715*14675a02SAndroid Build Coastguard Worker   }
716*14675a02SAndroid Build Coastguard Worker   for (absl::string_view chunk : std::get<absl::Cord>(data).Chunks()) {
717*14675a02SAndroid Build Coastguard Worker     if (!(stream << chunk).good()) {
718*14675a02SAndroid Build Coastguard Worker       return false;
719*14675a02SAndroid Build Coastguard Worker     }
720*14675a02SAndroid Build Coastguard Worker   }
721*14675a02SAndroid Build Coastguard Worker   return true;
722*14675a02SAndroid Build Coastguard Worker }
723*14675a02SAndroid Build Coastguard Worker // Writes the given checkpoint data to a newly created temporary file.
724*14675a02SAndroid Build Coastguard Worker // Returns the filename if successful, or an error if the file could not be
725*14675a02SAndroid Build Coastguard Worker // created, or if writing to the file failed.
CreateInputCheckpointFile(Files * files,const std::variant<std::string,absl::Cord> & checkpoint)726*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::string> CreateInputCheckpointFile(
727*14675a02SAndroid Build Coastguard Worker     Files* files, const std::variant<std::string, absl::Cord>& checkpoint) {
728*14675a02SAndroid Build Coastguard Worker   // Create the temporary checkpoint file.
729*14675a02SAndroid Build Coastguard Worker   // Deletion of the file is left to the caller / the Files implementation.
730*14675a02SAndroid Build Coastguard Worker   FCP_ASSIGN_OR_RETURN(absl::StatusOr<std::string> filename,
731*14675a02SAndroid Build Coastguard Worker                        files->CreateTempFile("init", ".ckp"));
732*14675a02SAndroid Build Coastguard Worker   // Write the checkpoint data to the file.
733*14675a02SAndroid Build Coastguard Worker   std::fstream checkpoint_stream(*filename, std::ios_base::out);
734*14675a02SAndroid Build Coastguard Worker   if (!WriteStringOrCordToFstream(checkpoint_stream, checkpoint)) {
735*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Failed to write to file");
736*14675a02SAndroid Build Coastguard Worker   }
737*14675a02SAndroid Build Coastguard Worker   checkpoint_stream.close();
738*14675a02SAndroid Build Coastguard Worker   return filename;
739*14675a02SAndroid Build Coastguard Worker }
RunEligibilityEvalPlan(const FederatedProtocol::EligibilityEvalTask & eligibility_eval_task,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const absl::Time time_before_checkin,const absl::Time time_before_plan_download,const NetworkStats & network_stats)740*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::optional<TaskEligibilityInfo>> RunEligibilityEvalPlan(
741*14675a02SAndroid Build Coastguard Worker     const FederatedProtocol::EligibilityEvalTask& eligibility_eval_task,
742*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
743*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
744*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
745*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol,
746*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
747*14675a02SAndroid Build Coastguard Worker     const absl::Time reference_time, const absl::Time time_before_checkin,
748*14675a02SAndroid Build Coastguard Worker     const absl::Time time_before_plan_download,
749*14675a02SAndroid Build Coastguard Worker     const NetworkStats& network_stats) {
750*14675a02SAndroid Build Coastguard Worker   ClientOnlyPlan plan;
751*14675a02SAndroid Build Coastguard Worker   if (!ParseFromStringOrCord(plan, eligibility_eval_task.payloads.plan)) {
752*14675a02SAndroid Build Coastguard Worker     auto message = "Failed to parse received eligibility eval plan";
753*14675a02SAndroid Build Coastguard Worker     phase_logger.LogEligibilityEvalCheckinInvalidPayloadError(
754*14675a02SAndroid Build Coastguard Worker         message, network_stats, time_before_plan_download);
755*14675a02SAndroid Build Coastguard Worker     FCP_LOG(ERROR) << message;
756*14675a02SAndroid Build Coastguard Worker     return absl::InternalError(message);
757*14675a02SAndroid Build Coastguard Worker   }
758*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::string> checkpoint_input_filename =
759*14675a02SAndroid Build Coastguard Worker       CreateInputCheckpointFile(files,
760*14675a02SAndroid Build Coastguard Worker                                 eligibility_eval_task.payloads.checkpoint);
761*14675a02SAndroid Build Coastguard Worker   if (!checkpoint_input_filename.ok()) {
762*14675a02SAndroid Build Coastguard Worker     auto status = checkpoint_input_filename.status();
763*14675a02SAndroid Build Coastguard Worker     auto message = absl::StrCat(
764*14675a02SAndroid Build Coastguard Worker         "Failed to create eligibility eval checkpoint input file: code: ",
765*14675a02SAndroid Build Coastguard Worker         status.code(), ", message: ", status.message());
766*14675a02SAndroid Build Coastguard Worker     phase_logger.LogEligibilityEvalCheckinIOError(status, network_stats,
767*14675a02SAndroid Build Coastguard Worker                                                   time_before_plan_download);
768*14675a02SAndroid Build Coastguard Worker     FCP_LOG(ERROR) << message;
769*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("");
770*14675a02SAndroid Build Coastguard Worker   }
771*14675a02SAndroid Build Coastguard Worker   phase_logger.LogEligibilityEvalCheckinCompleted(network_stats,
772*14675a02SAndroid Build Coastguard Worker                                                   /*time_before_checkin=*/
773*14675a02SAndroid Build Coastguard Worker                                                   time_before_checkin,
774*14675a02SAndroid Build Coastguard Worker                                                   /*time_before_plan_download=*/
775*14675a02SAndroid Build Coastguard Worker                                                   time_before_plan_download);
776*14675a02SAndroid Build Coastguard Worker   absl::Time run_plan_start_time = absl::Now();
777*14675a02SAndroid Build Coastguard Worker   phase_logger.LogEligibilityEvalComputationStarted();
778*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
779*14675a02SAndroid Build Coastguard Worker       example_iterator_factories, should_abort, log_manager, opstats_logger,
780*14675a02SAndroid Build Coastguard Worker       flags, plan, *checkpoint_input_filename, timing_config,
781*14675a02SAndroid Build Coastguard Worker       run_plan_start_time, reference_time);
782*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<TaskEligibilityInfo> task_eligibility_info;
783*14675a02SAndroid Build Coastguard Worker   if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
784*14675a02SAndroid Build Coastguard Worker     task_eligibility_info =
785*14675a02SAndroid Build Coastguard Worker         ParseEligibilityEvalPlanOutput(plan_result.output_tensors);
786*14675a02SAndroid Build Coastguard Worker   }
787*14675a02SAndroid Build Coastguard Worker   LogEligibilityEvalComputationOutcome(phase_logger, std::move(plan_result),
788*14675a02SAndroid Build Coastguard Worker                                        task_eligibility_info.status(),
789*14675a02SAndroid Build Coastguard Worker                                        run_plan_start_time, reference_time);
790*14675a02SAndroid Build Coastguard Worker   return task_eligibility_info;
791*14675a02SAndroid Build Coastguard Worker }
792*14675a02SAndroid Build Coastguard Worker struct EligibilityEvalResult {
793*14675a02SAndroid Build Coastguard Worker   std::optional<TaskEligibilityInfo> task_eligibility_info;
794*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> task_names_for_multiple_task_assignments;
795*14675a02SAndroid Build Coastguard Worker };
796*14675a02SAndroid Build Coastguard Worker // Create an EligibilityEvalResult from a TaskEligibilityInfo and
797*14675a02SAndroid Build Coastguard Worker // PopulationEligibilitySpec.  If both population_spec and
798*14675a02SAndroid Build Coastguard Worker // task_eligibility_info are present, the returned EligibilityEvalResult will
799*14675a02SAndroid Build Coastguard Worker // contain a TaskEligibilityInfo which only contains the tasks for single task
800*14675a02SAndroid Build Coastguard Worker // assignment, and a vector of task names for multiple task assignment.
CreateEligibilityEvalResult(const std::optional<TaskEligibilityInfo> & task_eligibility_info,const std::optional<PopulationEligibilitySpec> & population_spec)801*14675a02SAndroid Build Coastguard Worker EligibilityEvalResult CreateEligibilityEvalResult(
802*14675a02SAndroid Build Coastguard Worker     const std::optional<TaskEligibilityInfo>& task_eligibility_info,
803*14675a02SAndroid Build Coastguard Worker     const std::optional<PopulationEligibilitySpec>& population_spec) {
804*14675a02SAndroid Build Coastguard Worker   EligibilityEvalResult result;
805*14675a02SAndroid Build Coastguard Worker   if (population_spec.has_value() && task_eligibility_info.has_value()) {
806*14675a02SAndroid Build Coastguard Worker     absl::flat_hash_set<std::string> task_names_for_multiple_task_assignments;
807*14675a02SAndroid Build Coastguard Worker     for (const auto& task_info : population_spec.value().task_info()) {
808*14675a02SAndroid Build Coastguard Worker       if (task_info.task_assignment_mode() ==
809*14675a02SAndroid Build Coastguard Worker           PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE) {
810*14675a02SAndroid Build Coastguard Worker         task_names_for_multiple_task_assignments.insert(task_info.task_name());
811*14675a02SAndroid Build Coastguard Worker       }
812*14675a02SAndroid Build Coastguard Worker     }
813*14675a02SAndroid Build Coastguard Worker     TaskEligibilityInfo single_task_assignment_eligibility_info;
814*14675a02SAndroid Build Coastguard Worker     single_task_assignment_eligibility_info.set_version(
815*14675a02SAndroid Build Coastguard Worker         task_eligibility_info.value().version());
816*14675a02SAndroid Build Coastguard Worker     for (const auto& task_weight :
817*14675a02SAndroid Build Coastguard Worker          task_eligibility_info.value().task_weights()) {
818*14675a02SAndroid Build Coastguard Worker       if (task_names_for_multiple_task_assignments.contains(
819*14675a02SAndroid Build Coastguard Worker               task_weight.task_name())) {
820*14675a02SAndroid Build Coastguard Worker         result.task_names_for_multiple_task_assignments.push_back(
821*14675a02SAndroid Build Coastguard Worker             task_weight.task_name());
822*14675a02SAndroid Build Coastguard Worker       } else {
823*14675a02SAndroid Build Coastguard Worker         *single_task_assignment_eligibility_info.mutable_task_weights()->Add() =
824*14675a02SAndroid Build Coastguard Worker             task_weight;
825*14675a02SAndroid Build Coastguard Worker       }
826*14675a02SAndroid Build Coastguard Worker     }
827*14675a02SAndroid Build Coastguard Worker     result.task_eligibility_info = single_task_assignment_eligibility_info;
828*14675a02SAndroid Build Coastguard Worker   } else {
829*14675a02SAndroid Build Coastguard Worker     result.task_eligibility_info = task_eligibility_info;
830*14675a02SAndroid Build Coastguard Worker   }
831*14675a02SAndroid Build Coastguard Worker   return result;
832*14675a02SAndroid Build Coastguard Worker }
833*14675a02SAndroid Build Coastguard Worker // Issues an eligibility eval checkin request and executes the eligibility
834*14675a02SAndroid Build Coastguard Worker // eval task if the server returns one.
835*14675a02SAndroid Build Coastguard Worker //
836*14675a02SAndroid Build Coastguard Worker // This function modifies the FLRunnerResult with values received over the
837*14675a02SAndroid Build Coastguard Worker // course of the eligibility eval protocol interaction.
838*14675a02SAndroid Build Coastguard Worker //
839*14675a02SAndroid Build Coastguard Worker // Returns:
840*14675a02SAndroid Build Coastguard Worker // - the TaskEligibilityInfo produced by the eligibility eval task, if the
841*14675a02SAndroid Build Coastguard Worker //   server provided an eligibility eval task to run.
842*14675a02SAndroid Build Coastguard Worker // - an std::nullopt if the server indicated that there is no eligibility eval
843*14675a02SAndroid Build Coastguard Worker //   task configured for the population.
844*14675a02SAndroid Build Coastguard Worker // - an INTERNAL error if the server rejects the client or another error
845*14675a02SAndroid Build Coastguard Worker // occurs
846*14675a02SAndroid Build Coastguard Worker //   that should abort the training run. The error will already have been
847*14675a02SAndroid Build Coastguard Worker //   logged appropriately.
IssueEligibilityEvalCheckinAndRunPlan(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,FLRunnerResult & fl_runner_result)848*14675a02SAndroid Build Coastguard Worker absl::StatusOr<EligibilityEvalResult> IssueEligibilityEvalCheckinAndRunPlan(
849*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
850*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
851*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
852*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol,
853*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
854*14675a02SAndroid Build Coastguard Worker     const absl::Time reference_time, FLRunnerResult& fl_runner_result) {
855*14675a02SAndroid Build Coastguard Worker   const absl::Time time_before_checkin = absl::Now();
856*14675a02SAndroid Build Coastguard Worker   const NetworkStats network_stats_before_checkin =
857*14675a02SAndroid Build Coastguard Worker       GetCumulativeNetworkStats(federated_protocol,
858*14675a02SAndroid Build Coastguard Worker                                 /*fedselect_manager=*/nullptr);
859*14675a02SAndroid Build Coastguard Worker   // These fields will, after a successful checkin that resulted in an EET
860*14675a02SAndroid Build Coastguard Worker   // being received, contain the time at which the EET plan/checkpoint URIs
861*14675a02SAndroid Build Coastguard Worker   // were received (but not yet downloaded), as well as the cumulative network
862*14675a02SAndroid Build Coastguard Worker   // stats at that point, allowing us to separately calculate how long it took
863*14675a02SAndroid Build Coastguard Worker   // to then download the actual payloads.
864*14675a02SAndroid Build Coastguard Worker   absl::Time time_before_plan_download = time_before_checkin;
865*14675a02SAndroid Build Coastguard Worker   NetworkStats network_stats_before_plan_download =
866*14675a02SAndroid Build Coastguard Worker       network_stats_before_checkin;
867*14675a02SAndroid Build Coastguard Worker   // Log that we are about to check in with the server.
868*14675a02SAndroid Build Coastguard Worker   phase_logger.LogEligibilityEvalCheckinStarted();
869*14675a02SAndroid Build Coastguard Worker   // Issue the eligibility eval checkin request (providing a callback that
870*14675a02SAndroid Build Coastguard Worker   // will be called when an EET is assigned to the task but before its
871*14675a02SAndroid Build Coastguard Worker   // plan/checkpoint URIs have actually been downloaded).
872*14675a02SAndroid Build Coastguard Worker   bool plan_uris_received_callback_called = false;
873*14675a02SAndroid Build Coastguard Worker   std::function<void(const FederatedProtocol::EligibilityEvalTask&)>
874*14675a02SAndroid Build Coastguard Worker       plan_uris_received_callback =
875*14675a02SAndroid Build Coastguard Worker           [&time_before_plan_download, &network_stats_before_plan_download,
876*14675a02SAndroid Build Coastguard Worker            &time_before_checkin, &network_stats_before_checkin,
877*14675a02SAndroid Build Coastguard Worker            &federated_protocol, &phase_logger,
878*14675a02SAndroid Build Coastguard Worker            &plan_uris_received_callback_called](
879*14675a02SAndroid Build Coastguard Worker               const FederatedProtocol::EligibilityEvalTask& task) {
880*14675a02SAndroid Build Coastguard Worker             // When the plan URIs have been received, we already know the name
881*14675a02SAndroid Build Coastguard Worker             // of the task we have been assigned, so let's tell the
882*14675a02SAndroid Build Coastguard Worker             // PhaseLogger.
883*14675a02SAndroid Build Coastguard Worker             phase_logger.SetModelIdentifier(task.execution_id);
884*14675a02SAndroid Build Coastguard Worker             // We also should log a corresponding log event.
885*14675a02SAndroid Build Coastguard Worker             phase_logger.LogEligibilityEvalCheckinPlanUriReceived(
886*14675a02SAndroid Build Coastguard Worker                 GetNetworkStatsSince(federated_protocol,
887*14675a02SAndroid Build Coastguard Worker                                      /*fedselect_manager=*/nullptr,
888*14675a02SAndroid Build Coastguard Worker                                      network_stats_before_checkin),
889*14675a02SAndroid Build Coastguard Worker                 time_before_checkin);
890*14675a02SAndroid Build Coastguard Worker             // And we must take a snapshot of the current time & network
891*14675a02SAndroid Build Coastguard Worker             // stats, so we can distinguish between the duration/network stats
892*14675a02SAndroid Build Coastguard Worker             // incurred for the checkin request vs. the actual downloading of
893*14675a02SAndroid Build Coastguard Worker             // the plan/checkpoint resources.
894*14675a02SAndroid Build Coastguard Worker             time_before_plan_download = absl::Now();
895*14675a02SAndroid Build Coastguard Worker             network_stats_before_plan_download =
896*14675a02SAndroid Build Coastguard Worker                 GetCumulativeNetworkStats(federated_protocol,
897*14675a02SAndroid Build Coastguard Worker                                           /*fedselect_manager=*/nullptr);
898*14675a02SAndroid Build Coastguard Worker             plan_uris_received_callback_called = true;
899*14675a02SAndroid Build Coastguard Worker           };
900*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
901*14675a02SAndroid Build Coastguard Worker       eligibility_checkin_result = federated_protocol->EligibilityEvalCheckin(
902*14675a02SAndroid Build Coastguard Worker           plan_uris_received_callback);
903*14675a02SAndroid Build Coastguard Worker   UpdateRetryWindowAndNetworkStats(*federated_protocol,
904*14675a02SAndroid Build Coastguard Worker                                    /*fedselect_manager=*/nullptr, phase_logger,
905*14675a02SAndroid Build Coastguard Worker                                    fl_runner_result);
906*14675a02SAndroid Build Coastguard Worker   // It's a bit unfortunate that we have to inspect the checkin_result and
907*14675a02SAndroid Build Coastguard Worker   // extract the model identifier here rather than further down the function,
908*14675a02SAndroid Build Coastguard Worker   // but this ensures that the histograms below will have the right model
909*14675a02SAndroid Build Coastguard Worker   // identifier attached (and we want to also emit the histograms even if we
910*14675a02SAndroid Build Coastguard Worker   // have failed/rejected checkin outcomes).
911*14675a02SAndroid Build Coastguard Worker   if (eligibility_checkin_result.ok() &&
912*14675a02SAndroid Build Coastguard Worker       std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
913*14675a02SAndroid Build Coastguard Worker           *eligibility_checkin_result)) {
914*14675a02SAndroid Build Coastguard Worker     // Make sure that if we received an EligibilityEvalTask, then the callback
915*14675a02SAndroid Build Coastguard Worker     // should have already been called by this point by the protocol (ensuring
916*14675a02SAndroid Build Coastguard Worker     // that SetModelIdentifier has been called etc.).
917*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(plan_uris_received_callback_called);
918*14675a02SAndroid Build Coastguard Worker   }
919*14675a02SAndroid Build Coastguard Worker   if (!eligibility_checkin_result.ok()) {
920*14675a02SAndroid Build Coastguard Worker     auto status = eligibility_checkin_result.status();
921*14675a02SAndroid Build Coastguard Worker     auto message = absl::StrCat("Error during eligibility eval checkin: code: ",
922*14675a02SAndroid Build Coastguard Worker                                 status.code(), ", message: ", status.message());
923*14675a02SAndroid Build Coastguard Worker     if (status.code() == absl::StatusCode::kAborted) {
924*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalCheckinServerAborted(
925*14675a02SAndroid Build Coastguard Worker           status,
926*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
927*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
928*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
929*14675a02SAndroid Build Coastguard Worker           time_before_plan_download);
930*14675a02SAndroid Build Coastguard Worker     } else if (status.code() == absl::StatusCode::kCancelled) {
931*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalCheckinClientInterrupted(
932*14675a02SAndroid Build Coastguard Worker           status,
933*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
934*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
935*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
936*14675a02SAndroid Build Coastguard Worker           time_before_plan_download);
937*14675a02SAndroid Build Coastguard Worker     } else if (!status.ok()) {
938*14675a02SAndroid Build Coastguard Worker       phase_logger.LogEligibilityEvalCheckinIOError(
939*14675a02SAndroid Build Coastguard Worker           status,
940*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
941*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
942*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
943*14675a02SAndroid Build Coastguard Worker           time_before_plan_download);
944*14675a02SAndroid Build Coastguard Worker     }
945*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
946*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("");
947*14675a02SAndroid Build Coastguard Worker   }
948*14675a02SAndroid Build Coastguard Worker   EligibilityEvalResult result;
949*14675a02SAndroid Build Coastguard Worker   if (std::holds_alternative<FederatedProtocol::Rejection>(
950*14675a02SAndroid Build Coastguard Worker           *eligibility_checkin_result)) {
951*14675a02SAndroid Build Coastguard Worker     phase_logger.LogEligibilityEvalCheckinTurnedAway(
952*14675a02SAndroid Build Coastguard Worker         GetNetworkStatsSince(federated_protocol,
953*14675a02SAndroid Build Coastguard Worker                              /*fedselect_manager=*/nullptr,
954*14675a02SAndroid Build Coastguard Worker                              network_stats_before_checkin),
955*14675a02SAndroid Build Coastguard Worker         time_before_checkin);
956*14675a02SAndroid Build Coastguard Worker     // If the server explicitly rejected our request, then we must abort and
957*14675a02SAndroid Build Coastguard Worker     // we must not proceed to the "checkin" phase below.
958*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << "Device rejected by server during eligibility eval "
959*14675a02SAndroid Build Coastguard Worker                      "checkin; aborting";
960*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("");
961*14675a02SAndroid Build Coastguard Worker   } else if (std::holds_alternative<FederatedProtocol::EligibilityEvalDisabled>(
962*14675a02SAndroid Build Coastguard Worker                  *eligibility_checkin_result)) {
963*14675a02SAndroid Build Coastguard Worker     phase_logger.LogEligibilityEvalNotConfigured(
964*14675a02SAndroid Build Coastguard Worker         GetNetworkStatsSince(federated_protocol,
965*14675a02SAndroid Build Coastguard Worker                              /*fedselect_manager=*/nullptr,
966*14675a02SAndroid Build Coastguard Worker                              network_stats_before_checkin),
967*14675a02SAndroid Build Coastguard Worker         time_before_checkin);
968*14675a02SAndroid Build Coastguard Worker     // If the server indicates that no eligibility eval task is configured for
969*14675a02SAndroid Build Coastguard Worker     // the population then there is nothing more to do. We simply proceed to
970*14675a02SAndroid Build Coastguard Worker     // the "checkin" phase below without providing it a TaskEligibilityInfo
971*14675a02SAndroid Build Coastguard Worker     // proto.
972*14675a02SAndroid Build Coastguard Worker     result.task_eligibility_info = std::nullopt;
973*14675a02SAndroid Build Coastguard Worker     return result;
974*14675a02SAndroid Build Coastguard Worker   }
975*14675a02SAndroid Build Coastguard Worker   auto eligibility_eval_task =
976*14675a02SAndroid Build Coastguard Worker       absl::get<FederatedProtocol::EligibilityEvalTask>(
977*14675a02SAndroid Build Coastguard Worker           *eligibility_checkin_result);
978*14675a02SAndroid Build Coastguard Worker   // Parse and run the eligibility eval task if the server returned one.
979*14675a02SAndroid Build Coastguard Worker   // Now we have a EligibilityEvalTask, if an error happens, we will report to
980*14675a02SAndroid Build Coastguard Worker   // the server via the ReportEligibilityEvalError.
981*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::optional<TaskEligibilityInfo>> task_eligibility_info =
982*14675a02SAndroid Build Coastguard Worker       RunEligibilityEvalPlan(
983*14675a02SAndroid Build Coastguard Worker           eligibility_eval_task, example_iterator_factories, should_abort,
984*14675a02SAndroid Build Coastguard Worker           phase_logger, files, log_manager, opstats_logger, flags,
985*14675a02SAndroid Build Coastguard Worker           federated_protocol, timing_config, reference_time,
986*14675a02SAndroid Build Coastguard Worker           /*time_before_checkin=*/time_before_checkin,
987*14675a02SAndroid Build Coastguard Worker           /*time_before_plan_download=*/time_before_plan_download,
988*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
989*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
990*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download));
991*14675a02SAndroid Build Coastguard Worker   if (!task_eligibility_info.ok()) {
992*14675a02SAndroid Build Coastguard Worker     // Note that none of the PhaseLogger methods will reflect the very little
993*14675a02SAndroid Build Coastguard Worker     // amount of network usage the will be incurred by this protocol request.
994*14675a02SAndroid Build Coastguard Worker     // We consider this to be OK to keep things simple, and because this
995*14675a02SAndroid Build Coastguard Worker     // should use such a limited amount of network bandwidth. Do note that the
996*14675a02SAndroid Build Coastguard Worker     // network usage *will* be correctly reported in the OpStats database.
997*14675a02SAndroid Build Coastguard Worker     federated_protocol->ReportEligibilityEvalError(
998*14675a02SAndroid Build Coastguard Worker         absl::Status(task_eligibility_info.status().code(),
999*14675a02SAndroid Build Coastguard Worker                      "Failed to compute eligibility info"));
1000*14675a02SAndroid Build Coastguard Worker     UpdateRetryWindowAndNetworkStats(*federated_protocol,
1001*14675a02SAndroid Build Coastguard Worker                                      /*fedselect_manager=*/nullptr,
1002*14675a02SAndroid Build Coastguard Worker                                      phase_logger, fl_runner_result);
1003*14675a02SAndroid Build Coastguard Worker     return task_eligibility_info.status();
1004*14675a02SAndroid Build Coastguard Worker   }
1005*14675a02SAndroid Build Coastguard Worker   return CreateEligibilityEvalResult(
1006*14675a02SAndroid Build Coastguard Worker       *task_eligibility_info,
1007*14675a02SAndroid Build Coastguard Worker       eligibility_eval_task.population_eligibility_spec);
1008*14675a02SAndroid Build Coastguard Worker }
1009*14675a02SAndroid Build Coastguard Worker struct CheckinResult {
1010*14675a02SAndroid Build Coastguard Worker   std::string task_name;
1011*14675a02SAndroid Build Coastguard Worker   ClientOnlyPlan plan;
1012*14675a02SAndroid Build Coastguard Worker   int32_t minimum_clients_in_server_visible_aggregate;
1013*14675a02SAndroid Build Coastguard Worker   std::string checkpoint_input_filename;
1014*14675a02SAndroid Build Coastguard Worker   std::string computation_id;
1015*14675a02SAndroid Build Coastguard Worker   std::string federated_select_uri_template;
1016*14675a02SAndroid Build Coastguard Worker };
IssueCheckin(PhaseLogger & phase_logger,LogManager * log_manager,Files * files,FederatedProtocol * federated_protocol,std::optional<TaskEligibilityInfo> task_eligibility_info,absl::Time reference_time,const std::string & population_name,FLRunnerResult & fl_runner_result,const Flags * flags)1017*14675a02SAndroid Build Coastguard Worker absl::StatusOr<CheckinResult> IssueCheckin(
1018*14675a02SAndroid Build Coastguard Worker     PhaseLogger& phase_logger, LogManager* log_manager, Files* files,
1019*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol,
1020*14675a02SAndroid Build Coastguard Worker     std::optional<TaskEligibilityInfo> task_eligibility_info,
1021*14675a02SAndroid Build Coastguard Worker     absl::Time reference_time, const std::string& population_name,
1022*14675a02SAndroid Build Coastguard Worker     FLRunnerResult& fl_runner_result, const Flags* flags) {
1023*14675a02SAndroid Build Coastguard Worker   absl::Time time_before_checkin = absl::Now();
1024*14675a02SAndroid Build Coastguard Worker   // We must return only stats that cover the check in phase for the log
1025*14675a02SAndroid Build Coastguard Worker   // events below.
1026*14675a02SAndroid Build Coastguard Worker   const NetworkStats network_stats_before_checkin =
1027*14675a02SAndroid Build Coastguard Worker       GetCumulativeNetworkStats(federated_protocol,
1028*14675a02SAndroid Build Coastguard Worker                                 /*fedselect_manager=*/nullptr);
1029*14675a02SAndroid Build Coastguard Worker   // These fields will, after a successful checkin that resulted in a task
1030*14675a02SAndroid Build Coastguard Worker   // being assigned, contain the time at which the task plan/checkpoint URIs
1031*14675a02SAndroid Build Coastguard Worker   // were received (but not yet downloaded), as well as the cumulative network
1032*14675a02SAndroid Build Coastguard Worker   // stats at that point, allowing us to separately calculate how long it took
1033*14675a02SAndroid Build Coastguard Worker   // to then download the actual payloads.
1034*14675a02SAndroid Build Coastguard Worker   absl::Time time_before_plan_download = time_before_checkin;
1035*14675a02SAndroid Build Coastguard Worker   NetworkStats network_stats_before_plan_download =
1036*14675a02SAndroid Build Coastguard Worker       network_stats_before_checkin;
1037*14675a02SAndroid Build Coastguard Worker   // Clear the model identifier before check-in, to ensure that the any prior
1038*14675a02SAndroid Build Coastguard Worker   // eligibility eval task name isn't used any longer.
1039*14675a02SAndroid Build Coastguard Worker   phase_logger.SetModelIdentifier("");
1040*14675a02SAndroid Build Coastguard Worker   phase_logger.LogCheckinStarted();
1041*14675a02SAndroid Build Coastguard Worker   std::string task_name;
1042*14675a02SAndroid Build Coastguard Worker   // Issue the checkin request (providing a callback that will be called when
1043*14675a02SAndroid Build Coastguard Worker   // an EET is assigned to the task but before its plan/checkpoint URIs have
1044*14675a02SAndroid Build Coastguard Worker   // actually been downloaded).
1045*14675a02SAndroid Build Coastguard Worker   bool plan_uris_received_callback_called = false;
1046*14675a02SAndroid Build Coastguard Worker   std::function<void(const FederatedProtocol::TaskAssignment&)>
1047*14675a02SAndroid Build Coastguard Worker       plan_uris_received_callback =
1048*14675a02SAndroid Build Coastguard Worker           [&time_before_plan_download, &network_stats_before_plan_download,
1049*14675a02SAndroid Build Coastguard Worker            &time_before_checkin, &network_stats_before_checkin, &task_name,
1050*14675a02SAndroid Build Coastguard Worker            &federated_protocol, &population_name, &log_manager, &phase_logger,
1051*14675a02SAndroid Build Coastguard Worker            &plan_uris_received_callback_called](
1052*14675a02SAndroid Build Coastguard Worker               const FederatedProtocol::TaskAssignment& task_assignment) {
1053*14675a02SAndroid Build Coastguard Worker             // When the plan URIs have been received, we already know the name
1054*14675a02SAndroid Build Coastguard Worker             // of the task we have been assigned, so let's tell the
1055*14675a02SAndroid Build Coastguard Worker             // PhaseLogger.
1056*14675a02SAndroid Build Coastguard Worker             auto model_identifier = task_assignment.aggregation_session_id;
1057*14675a02SAndroid Build Coastguard Worker             phase_logger.SetModelIdentifier(model_identifier);
1058*14675a02SAndroid Build Coastguard Worker             // We also should log a corresponding log event.
1059*14675a02SAndroid Build Coastguard Worker             task_name = ExtractTaskNameFromAggregationSessionId(
1060*14675a02SAndroid Build Coastguard Worker                 model_identifier, population_name, *log_manager);
1061*14675a02SAndroid Build Coastguard Worker             phase_logger.LogCheckinPlanUriReceived(
1062*14675a02SAndroid Build Coastguard Worker                 task_name,
1063*14675a02SAndroid Build Coastguard Worker                 GetNetworkStatsSince(federated_protocol,
1064*14675a02SAndroid Build Coastguard Worker                                      /*fedselect_manager=*/nullptr,
1065*14675a02SAndroid Build Coastguard Worker                                      network_stats_before_checkin),
1066*14675a02SAndroid Build Coastguard Worker                 time_before_checkin);
1067*14675a02SAndroid Build Coastguard Worker             // And we must take a snapshot of the current time & network
1068*14675a02SAndroid Build Coastguard Worker             // stats, so we can distinguish between the duration/network stats
1069*14675a02SAndroid Build Coastguard Worker             // incurred for the checkin request vs. the actual downloading of
1070*14675a02SAndroid Build Coastguard Worker             // the plan/checkpoint resources.
1071*14675a02SAndroid Build Coastguard Worker             time_before_plan_download = absl::Now();
1072*14675a02SAndroid Build Coastguard Worker             network_stats_before_plan_download = GetCumulativeNetworkStats(
1073*14675a02SAndroid Build Coastguard Worker                 federated_protocol, /*fedselect_manager=*/nullptr);
1074*14675a02SAndroid Build Coastguard Worker             plan_uris_received_callback_called = true;
1075*14675a02SAndroid Build Coastguard Worker           };
1076*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<FederatedProtocol::CheckinResult> checkin_result =
1077*14675a02SAndroid Build Coastguard Worker       federated_protocol->Checkin(task_eligibility_info,
1078*14675a02SAndroid Build Coastguard Worker                                   plan_uris_received_callback);
1079*14675a02SAndroid Build Coastguard Worker   UpdateRetryWindowAndNetworkStats(*federated_protocol,
1080*14675a02SAndroid Build Coastguard Worker                                    /*fedselect_manager=*/nullptr, phase_logger,
1081*14675a02SAndroid Build Coastguard Worker                                    fl_runner_result);
1082*14675a02SAndroid Build Coastguard Worker   // It's a bit unfortunate that we have to inspect the checkin_result and
1083*14675a02SAndroid Build Coastguard Worker   // extract the model identifier here rather than further down the function,
1084*14675a02SAndroid Build Coastguard Worker   // but this ensures that the histograms below will have the right model
1085*14675a02SAndroid Build Coastguard Worker   // identifier attached (and we want to also emit the histograms even if we
1086*14675a02SAndroid Build Coastguard Worker   // have failed/rejected checkin outcomes).
1087*14675a02SAndroid Build Coastguard Worker   if (checkin_result.ok() &&
1088*14675a02SAndroid Build Coastguard Worker       std::holds_alternative<FederatedProtocol::TaskAssignment>(
1089*14675a02SAndroid Build Coastguard Worker           *checkin_result)) {
1090*14675a02SAndroid Build Coastguard Worker     // Make sure that if we received a TaskAssignment, then the callback
1091*14675a02SAndroid Build Coastguard Worker     // should have already been called by this point by the protocol (ensuring
1092*14675a02SAndroid Build Coastguard Worker     // that SetModelIdentifier has been called etc.).
1093*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(plan_uris_received_callback_called);
1094*14675a02SAndroid Build Coastguard Worker   }
1095*14675a02SAndroid Build Coastguard Worker   if (!checkin_result.ok()) {
1096*14675a02SAndroid Build Coastguard Worker     auto status = checkin_result.status();
1097*14675a02SAndroid Build Coastguard Worker     auto message = absl::StrCat("Error during checkin: code: ", status.code(),
1098*14675a02SAndroid Build Coastguard Worker                                 ", message: ", status.message());
1099*14675a02SAndroid Build Coastguard Worker     if (status.code() == absl::StatusCode::kAborted) {
1100*14675a02SAndroid Build Coastguard Worker       phase_logger.LogCheckinServerAborted(
1101*14675a02SAndroid Build Coastguard Worker           status,
1102*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
1103*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
1104*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
1105*14675a02SAndroid Build Coastguard Worker           time_before_plan_download, reference_time);
1106*14675a02SAndroid Build Coastguard Worker     } else if (status.code() == absl::StatusCode::kCancelled) {
1107*14675a02SAndroid Build Coastguard Worker       phase_logger.LogCheckinClientInterrupted(
1108*14675a02SAndroid Build Coastguard Worker           status,
1109*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
1110*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
1111*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
1112*14675a02SAndroid Build Coastguard Worker           time_before_plan_download, reference_time);
1113*14675a02SAndroid Build Coastguard Worker     } else if (!status.ok()) {
1114*14675a02SAndroid Build Coastguard Worker       phase_logger.LogCheckinIOError(
1115*14675a02SAndroid Build Coastguard Worker           status,
1116*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
1117*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
1118*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
1119*14675a02SAndroid Build Coastguard Worker           time_before_plan_download, reference_time);
1120*14675a02SAndroid Build Coastguard Worker     }
1121*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
1122*14675a02SAndroid Build Coastguard Worker     return status;
1123*14675a02SAndroid Build Coastguard Worker   }
1124*14675a02SAndroid Build Coastguard Worker   // Server rejected us? Return the fl_runner_results as-is.
1125*14675a02SAndroid Build Coastguard Worker   if (std::holds_alternative<FederatedProtocol::Rejection>(*checkin_result)) {
1126*14675a02SAndroid Build Coastguard Worker     phase_logger.LogCheckinTurnedAway(
1127*14675a02SAndroid Build Coastguard Worker         GetNetworkStatsSince(federated_protocol,
1128*14675a02SAndroid Build Coastguard Worker                              /*fedselect_manager=*/nullptr,
1129*14675a02SAndroid Build Coastguard Worker                              network_stats_before_checkin),
1130*14675a02SAndroid Build Coastguard Worker         time_before_checkin, reference_time);
1131*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << "Device rejected by server during checkin; aborting";
1132*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("Device rejected by server.");
1133*14675a02SAndroid Build Coastguard Worker   }
1134*14675a02SAndroid Build Coastguard Worker   auto task_assignment =
1135*14675a02SAndroid Build Coastguard Worker       absl::get<FederatedProtocol::TaskAssignment>(*checkin_result);
1136*14675a02SAndroid Build Coastguard Worker   ClientOnlyPlan plan;
1137*14675a02SAndroid Build Coastguard Worker   auto plan_bytes = task_assignment.payloads.plan;
1138*14675a02SAndroid Build Coastguard Worker   if (!ParseFromStringOrCord(plan, plan_bytes)) {
1139*14675a02SAndroid Build Coastguard Worker     auto message = "Failed to parse received plan";
1140*14675a02SAndroid Build Coastguard Worker     phase_logger.LogCheckinInvalidPayload(
1141*14675a02SAndroid Build Coastguard Worker         message,
1142*14675a02SAndroid Build Coastguard Worker         GetNetworkStatsSince(federated_protocol,
1143*14675a02SAndroid Build Coastguard Worker                              /*fedselect_manager=*/nullptr,
1144*14675a02SAndroid Build Coastguard Worker                              network_stats_before_plan_download),
1145*14675a02SAndroid Build Coastguard Worker         time_before_plan_download, reference_time);
1146*14675a02SAndroid Build Coastguard Worker     FCP_LOG(ERROR) << message;
1147*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("");
1148*14675a02SAndroid Build Coastguard Worker   }
1149*14675a02SAndroid Build Coastguard Worker   std::string computation_id;
1150*14675a02SAndroid Build Coastguard Worker   if (flags->enable_computation_id()) {
1151*14675a02SAndroid Build Coastguard Worker     computation_id = ComputeSHA256FromStringOrCord(plan_bytes);
1152*14675a02SAndroid Build Coastguard Worker   }
1153*14675a02SAndroid Build Coastguard Worker   int32_t minimum_clients_in_server_visible_aggregate = 0;
1154*14675a02SAndroid Build Coastguard Worker   if (task_assignment.sec_agg_info.has_value()) {
1155*14675a02SAndroid Build Coastguard Worker     auto minimum_number_of_participants =
1156*14675a02SAndroid Build Coastguard Worker         plan.phase().minimum_number_of_participants();
1157*14675a02SAndroid Build Coastguard Worker     if (task_assignment.sec_agg_info->expected_number_of_clients <
1158*14675a02SAndroid Build Coastguard Worker         minimum_number_of_participants) {
1159*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(
1160*14675a02SAndroid Build Coastguard Worker           "expectedNumberOfClients was less than Plan's "
1161*14675a02SAndroid Build Coastguard Worker           "minimumNumberOfParticipants.");
1162*14675a02SAndroid Build Coastguard Worker     }
1163*14675a02SAndroid Build Coastguard Worker     minimum_clients_in_server_visible_aggregate =
1164*14675a02SAndroid Build Coastguard Worker         task_assignment.sec_agg_info
1165*14675a02SAndroid Build Coastguard Worker             ->minimum_clients_in_server_visible_aggregate;
1166*14675a02SAndroid Build Coastguard Worker   }
1167*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::string> checkpoint_input_filename = "";
1168*14675a02SAndroid Build Coastguard Worker   // Example query plan does not have an input checkpoint.
1169*14675a02SAndroid Build Coastguard Worker   if (!plan.phase().has_example_query_spec()) {
1170*14675a02SAndroid Build Coastguard Worker     checkpoint_input_filename =
1171*14675a02SAndroid Build Coastguard Worker         CreateInputCheckpointFile(files, task_assignment.payloads.checkpoint);
1172*14675a02SAndroid Build Coastguard Worker     if (!checkpoint_input_filename.ok()) {
1173*14675a02SAndroid Build Coastguard Worker       auto status = checkpoint_input_filename.status();
1174*14675a02SAndroid Build Coastguard Worker       auto message = absl::StrCat(
1175*14675a02SAndroid Build Coastguard Worker           "Failed to create checkpoint input file: code: ", status.code(),
1176*14675a02SAndroid Build Coastguard Worker           ", message: ", status.message());
1177*14675a02SAndroid Build Coastguard Worker       phase_logger.LogCheckinIOError(
1178*14675a02SAndroid Build Coastguard Worker           status,
1179*14675a02SAndroid Build Coastguard Worker           GetNetworkStatsSince(federated_protocol,
1180*14675a02SAndroid Build Coastguard Worker                                /*fedselect_manager=*/nullptr,
1181*14675a02SAndroid Build Coastguard Worker                                network_stats_before_plan_download),
1182*14675a02SAndroid Build Coastguard Worker           time_before_plan_download, reference_time);
1183*14675a02SAndroid Build Coastguard Worker       FCP_LOG(ERROR) << message;
1184*14675a02SAndroid Build Coastguard Worker       return status;
1185*14675a02SAndroid Build Coastguard Worker     }
1186*14675a02SAndroid Build Coastguard Worker   }
1187*14675a02SAndroid Build Coastguard Worker   phase_logger.LogCheckinCompleted(
1188*14675a02SAndroid Build Coastguard Worker       task_name,
1189*14675a02SAndroid Build Coastguard Worker       GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr,
1190*14675a02SAndroid Build Coastguard Worker                            network_stats_before_plan_download),
1191*14675a02SAndroid Build Coastguard Worker       /*time_before_checkin=*/time_before_checkin,
1192*14675a02SAndroid Build Coastguard Worker       /*time_before_plan_download=*/time_before_plan_download, reference_time);
1193*14675a02SAndroid Build Coastguard Worker   return CheckinResult{
1194*14675a02SAndroid Build Coastguard Worker       .task_name = std::move(task_name),
1195*14675a02SAndroid Build Coastguard Worker       .plan = std::move(plan),
1196*14675a02SAndroid Build Coastguard Worker       .minimum_clients_in_server_visible_aggregate =
1197*14675a02SAndroid Build Coastguard Worker           minimum_clients_in_server_visible_aggregate,
1198*14675a02SAndroid Build Coastguard Worker       .checkpoint_input_filename = std::move(*checkpoint_input_filename),
1199*14675a02SAndroid Build Coastguard Worker       .computation_id = std::move(computation_id),
1200*14675a02SAndroid Build Coastguard Worker       .federated_select_uri_template =
1201*14675a02SAndroid Build Coastguard Worker           task_assignment.federated_select_uri_template};
1202*14675a02SAndroid Build Coastguard Worker }
1203*14675a02SAndroid Build Coastguard Worker }  // namespace
RunFederatedComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,const std::string & session_name,const std::string & population_name,const std::string & retry_token,const std::string & client_version,const std::string & attestation_measurement)1204*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1205*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1206*14675a02SAndroid Build Coastguard Worker     Files* files, LogManager* log_manager, const Flags* flags,
1207*14675a02SAndroid Build Coastguard Worker     const std::string& federated_service_uri, const std::string& api_key,
1208*14675a02SAndroid Build Coastguard Worker     const std::string& test_cert_path, const std::string& session_name,
1209*14675a02SAndroid Build Coastguard Worker     const std::string& population_name, const std::string& retry_token,
1210*14675a02SAndroid Build Coastguard Worker     const std::string& client_version,
1211*14675a02SAndroid Build Coastguard Worker     const std::string& attestation_measurement) {
1212*14675a02SAndroid Build Coastguard Worker   auto opstats_logger =
1213*14675a02SAndroid Build Coastguard Worker       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1214*14675a02SAndroid Build Coastguard Worker                                   session_name, population_name);
1215*14675a02SAndroid Build Coastguard Worker   absl::Time reference_time = absl::Now();
1216*14675a02SAndroid Build Coastguard Worker   FLRunnerResult fl_runner_result;
1217*14675a02SAndroid Build Coastguard Worker   fcp::client::InterruptibleRunner::TimingConfig timing_config = {
1218*14675a02SAndroid Build Coastguard Worker       .polling_period =
1219*14675a02SAndroid Build Coastguard Worker           absl::Milliseconds(flags->condition_polling_period_millis()),
1220*14675a02SAndroid Build Coastguard Worker       .graceful_shutdown_period = absl::Milliseconds(
1221*14675a02SAndroid Build Coastguard Worker           flags->tf_execution_teardown_grace_period_millis()),
1222*14675a02SAndroid Build Coastguard Worker       .extended_shutdown_period = absl::Milliseconds(
1223*14675a02SAndroid Build Coastguard Worker           flags->tf_execution_teardown_extended_period_millis()),
1224*14675a02SAndroid Build Coastguard Worker   };
1225*14675a02SAndroid Build Coastguard Worker   auto should_abort_protocol_callback = [&env_deps, &timing_config]() -> bool {
1226*14675a02SAndroid Build Coastguard Worker     // Return the Status if failed, or the negated value if successful.
1227*14675a02SAndroid Build Coastguard Worker     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1228*14675a02SAndroid Build Coastguard Worker   };
1229*14675a02SAndroid Build Coastguard Worker   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1230*14675a02SAndroid Build Coastguard Worker                                log_manager, flags);
1231*14675a02SAndroid Build Coastguard Worker   // If there was an error initializing OpStats, opstats_logger will be a
1232*14675a02SAndroid Build Coastguard Worker   // no-op implementation and execution will be allowed to continue.
1233*14675a02SAndroid Build Coastguard Worker   if (!opstats_logger->GetInitStatus().ok()) {
1234*14675a02SAndroid Build Coastguard Worker     // This will only happen if OpStats is enabled and there was an error in
1235*14675a02SAndroid Build Coastguard Worker     // initialization.
1236*14675a02SAndroid Build Coastguard Worker     phase_logger.LogNonfatalInitializationError(
1237*14675a02SAndroid Build Coastguard Worker         opstats_logger->GetInitStatus());
1238*14675a02SAndroid Build Coastguard Worker   }
1239*14675a02SAndroid Build Coastguard Worker   Clock* clock = Clock::RealClock();
1240*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<cache::ResourceCache> resource_cache;
1241*14675a02SAndroid Build Coastguard Worker   // if (flags->max_resource_cache_size_bytes() > 0) {
1242*14675a02SAndroid Build Coastguard Worker   //   // Anything that goes wrong in FileBackedResourceCache::Create is a
1243*14675a02SAndroid Build Coastguard Worker   //   // programmer error.
1244*14675a02SAndroid Build Coastguard Worker   //   absl::StatusOr<std::unique_ptr<cache::ResourceCache>>
1245*14675a02SAndroid Build Coastguard Worker   //       resource_cache_internal = cache::FileBackedResourceCache::Create(
1246*14675a02SAndroid Build Coastguard Worker   //           env_deps->GetBaseDir(), env_deps->GetCacheDir(), log_manager,
1247*14675a02SAndroid Build Coastguard Worker   //           clock, flags->max_resource_cache_size_bytes());
1248*14675a02SAndroid Build Coastguard Worker   //   if (!resource_cache_internal.ok()) {
1249*14675a02SAndroid Build Coastguard Worker   //     auto resource_init_failed_status = absl::Status(
1250*14675a02SAndroid Build Coastguard Worker   //         resource_cache_internal.status().code(),
1251*14675a02SAndroid Build Coastguard Worker   //         absl::StrCat("Failed to initialize FileBackedResourceCache: ",
1252*14675a02SAndroid Build Coastguard Worker   //                      resource_cache_internal.status().ToString()));
1253*14675a02SAndroid Build Coastguard Worker   //     if (flags->resource_cache_initialization_error_is_fatal()) {
1254*14675a02SAndroid Build Coastguard Worker   //       phase_logger.LogFatalInitializationError(resource_init_failed_status);
1255*14675a02SAndroid Build Coastguard Worker   //       return resource_init_failed_status;
1256*14675a02SAndroid Build Coastguard Worker   //     }
1257*14675a02SAndroid Build Coastguard Worker   //     // We log an error but otherwise proceed as if the cache was disabled.
1258*14675a02SAndroid Build Coastguard Worker   //     phase_logger.LogNonfatalInitializationError(resource_init_failed_status);
1259*14675a02SAndroid Build Coastguard Worker   //   } else {
1260*14675a02SAndroid Build Coastguard Worker   //     resource_cache = std::move(*resource_cache_internal);
1261*14675a02SAndroid Build Coastguard Worker   //   }
1262*14675a02SAndroid Build Coastguard Worker   // }
1263*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<::fcp::client::http::HttpClient> http_client =
1264*14675a02SAndroid Build Coastguard Worker       flags->enable_grpc_with_http_resource_support() ||
1265*14675a02SAndroid Build Coastguard Worker               flags->use_http_federated_compute_protocol()
1266*14675a02SAndroid Build Coastguard Worker           ? env_deps->CreateHttpClient()
1267*14675a02SAndroid Build Coastguard Worker           : nullptr;
1268*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<FederatedProtocol> federated_protocol;
1269*14675a02SAndroid Build Coastguard Worker   if (flags->use_http_federated_compute_protocol()) {
1270*14675a02SAndroid Build Coastguard Worker     log_manager->LogDiag(ProdDiagCode::HTTP_FEDERATED_PROTOCOL_USED);
1271*14675a02SAndroid Build Coastguard Worker     // Verify the entry point uri starts with "https://" or
1272*14675a02SAndroid Build Coastguard Worker     // "http://localhost". Note "http://localhost" is allowed for testing
1273*14675a02SAndroid Build Coastguard Worker     // purpose.
1274*14675a02SAndroid Build Coastguard Worker     if (!(absl::StartsWith(federated_service_uri, "https://") ||
1275*14675a02SAndroid Build Coastguard Worker           absl::StartsWith(federated_service_uri, "http://localhost"))) {
1276*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("The entry point uri is invalid.");
1277*14675a02SAndroid Build Coastguard Worker     }
1278*14675a02SAndroid Build Coastguard Worker     federated_protocol = std::make_unique<http::HttpFederatedProtocol>(
1279*14675a02SAndroid Build Coastguard Worker         clock, log_manager, flags, http_client.get(),
1280*14675a02SAndroid Build Coastguard Worker         std::make_unique<SecAggRunnerFactoryImpl>(),
1281*14675a02SAndroid Build Coastguard Worker         event_publisher->secagg_event_publisher(), federated_service_uri,
1282*14675a02SAndroid Build Coastguard Worker         api_key, population_name, retry_token, client_version,
1283*14675a02SAndroid Build Coastguard Worker         attestation_measurement, should_abort_protocol_callback, absl::BitGen(),
1284*14675a02SAndroid Build Coastguard Worker         timing_config, resource_cache.get());
1285*14675a02SAndroid Build Coastguard Worker   } else {
1286*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_GRPC
1287*14675a02SAndroid Build Coastguard Worker     // Check in with the server to either retrieve a plan + initial
1288*14675a02SAndroid Build Coastguard Worker     // checkpoint, or get rejected with a RetryWindow.
1289*14675a02SAndroid Build Coastguard Worker     auto grpc_channel_deadline = flags->grpc_channel_deadline_seconds();
1290*14675a02SAndroid Build Coastguard Worker     if (grpc_channel_deadline <= 0) {
1291*14675a02SAndroid Build Coastguard Worker       grpc_channel_deadline = 600;
1292*14675a02SAndroid Build Coastguard Worker       FCP_LOG(INFO) << "Using default channel deadline of "
1293*14675a02SAndroid Build Coastguard Worker                     << grpc_channel_deadline << " seconds.";
1294*14675a02SAndroid Build Coastguard Worker     }
1295*14675a02SAndroid Build Coastguard Worker     federated_protocol = std::make_unique<GrpcFederatedProtocol>(
1296*14675a02SAndroid Build Coastguard Worker         event_publisher, log_manager,
1297*14675a02SAndroid Build Coastguard Worker         std::make_unique<SecAggRunnerFactoryImpl>(), flags, http_client.get(),
1298*14675a02SAndroid Build Coastguard Worker         federated_service_uri, api_key, test_cert_path, population_name,
1299*14675a02SAndroid Build Coastguard Worker         retry_token, client_version, attestation_measurement,
1300*14675a02SAndroid Build Coastguard Worker         should_abort_protocol_callback, timing_config, grpc_channel_deadline,
1301*14675a02SAndroid Build Coastguard Worker         resource_cache.get());
1302*14675a02SAndroid Build Coastguard Worker #else
1303*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("No FederatedProtocol enabled.");
1304*14675a02SAndroid Build Coastguard Worker #endif
1305*14675a02SAndroid Build Coastguard Worker   }
1306*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<FederatedSelectManager> federated_select_manager;
1307*14675a02SAndroid Build Coastguard Worker   if (http_client != nullptr && flags->enable_federated_select()) {
1308*14675a02SAndroid Build Coastguard Worker     federated_select_manager = std::make_unique<HttpFederatedSelectManager>(
1309*14675a02SAndroid Build Coastguard Worker         log_manager, files, http_client.get(), should_abort_protocol_callback,
1310*14675a02SAndroid Build Coastguard Worker         timing_config);
1311*14675a02SAndroid Build Coastguard Worker   } else {
1312*14675a02SAndroid Build Coastguard Worker     federated_select_manager =
1313*14675a02SAndroid Build Coastguard Worker         std::make_unique<DisabledFederatedSelectManager>(log_manager);
1314*14675a02SAndroid Build Coastguard Worker   }
1315*14675a02SAndroid Build Coastguard Worker   return RunFederatedComputation(env_deps, phase_logger, event_publisher, files,
1316*14675a02SAndroid Build Coastguard Worker                                  log_manager, opstats_logger.get(), flags,
1317*14675a02SAndroid Build Coastguard Worker                                  federated_protocol.get(),
1318*14675a02SAndroid Build Coastguard Worker                                  federated_select_manager.get(), timing_config,
1319*14675a02SAndroid Build Coastguard Worker                                  reference_time, session_name, population_name);
1320*14675a02SAndroid Build Coastguard Worker }
RunFederatedComputation(SimpleTaskEnvironment * env_deps,PhaseLogger & phase_logger,EventPublisher * event_publisher,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const std::string & session_name,const std::string & population_name)1321*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1322*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger,
1323*14675a02SAndroid Build Coastguard Worker     EventPublisher* event_publisher, Files* files, LogManager* log_manager,
1324*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
1325*14675a02SAndroid Build Coastguard Worker     FederatedProtocol* federated_protocol,
1326*14675a02SAndroid Build Coastguard Worker     FederatedSelectManager* fedselect_manager,
1327*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1328*14675a02SAndroid Build Coastguard Worker     const absl::Time reference_time, const std::string& session_name,
1329*14675a02SAndroid Build Coastguard Worker     const std::string& population_name) {
1330*14675a02SAndroid Build Coastguard Worker   SelectorContext federated_selector_context;
1331*14675a02SAndroid Build Coastguard Worker   federated_selector_context.mutable_computation_properties()->set_session_name(
1332*14675a02SAndroid Build Coastguard Worker       session_name);
1333*14675a02SAndroid Build Coastguard Worker   FederatedComputation federated_computation;
1334*14675a02SAndroid Build Coastguard Worker   federated_computation.set_population_name(population_name);
1335*14675a02SAndroid Build Coastguard Worker   *federated_selector_context.mutable_computation_properties()
1336*14675a02SAndroid Build Coastguard Worker        ->mutable_federated() = federated_computation;
1337*14675a02SAndroid Build Coastguard Worker   SelectorContext eligibility_selector_context;
1338*14675a02SAndroid Build Coastguard Worker   eligibility_selector_context.mutable_computation_properties()
1339*14675a02SAndroid Build Coastguard Worker       ->set_session_name(session_name);
1340*14675a02SAndroid Build Coastguard Worker   EligibilityEvalComputation eligibility_eval_computation;
1341*14675a02SAndroid Build Coastguard Worker   eligibility_eval_computation.set_population_name(population_name);
1342*14675a02SAndroid Build Coastguard Worker   *eligibility_selector_context.mutable_computation_properties()
1343*14675a02SAndroid Build Coastguard Worker        ->mutable_eligibility_eval() = eligibility_eval_computation;
1344*14675a02SAndroid Build Coastguard Worker   // Construct a default FLRunnerResult that reflects an unsuccessful training
1345*14675a02SAndroid Build Coastguard Worker   // attempt and which uses RetryWindow corresponding to transient errors (if
1346*14675a02SAndroid Build Coastguard Worker   // the flag is on).
1347*14675a02SAndroid Build Coastguard Worker   // This is what will be returned if we have to bail early, before we've
1348*14675a02SAndroid Build Coastguard Worker   // received a RetryWindow from the server.
1349*14675a02SAndroid Build Coastguard Worker   FLRunnerResult fl_runner_result;
1350*14675a02SAndroid Build Coastguard Worker   fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
1351*14675a02SAndroid Build Coastguard Worker   // Before we even check whether we should abort right away, update the retry
1352*14675a02SAndroid Build Coastguard Worker   // window. That way we will use the most appropriate retry window we have
1353*14675a02SAndroid Build Coastguard Worker   // available (an implementation detail of FederatedProtocol, but generally a
1354*14675a02SAndroid Build Coastguard Worker   // 'transient error' retry window based on the provided flag values) in case
1355*14675a02SAndroid Build Coastguard Worker   // we do need to abort.
1356*14675a02SAndroid Build Coastguard Worker   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1357*14675a02SAndroid Build Coastguard Worker                                    phase_logger, fl_runner_result);
1358*14675a02SAndroid Build Coastguard Worker   // Check if the device conditions allow for checking in with the server
1359*14675a02SAndroid Build Coastguard Worker   // and running a federated computation. If not, bail early with the
1360*14675a02SAndroid Build Coastguard Worker   // transient error retry window.
1361*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [env_deps, &timing_config]() {
1362*14675a02SAndroid Build Coastguard Worker     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1363*14675a02SAndroid Build Coastguard Worker   };
1364*14675a02SAndroid Build Coastguard Worker   if (should_abort()) {
1365*14675a02SAndroid Build Coastguard Worker     std::string message =
1366*14675a02SAndroid Build Coastguard Worker         "Device conditions not satisfied, aborting federated computation";
1367*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
1368*14675a02SAndroid Build Coastguard Worker     phase_logger.LogTaskNotStarted(message);
1369*14675a02SAndroid Build Coastguard Worker     return fl_runner_result;
1370*14675a02SAndroid Build Coastguard Worker   }
1371*14675a02SAndroid Build Coastguard Worker   // Eligibility eval plans can use example iterators from the
1372*14675a02SAndroid Build Coastguard Worker   // SimpleTaskEnvironment and those reading the OpStats DB.
1373*14675a02SAndroid Build Coastguard Worker   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1374*14675a02SAndroid Build Coastguard Worker       opstats_logger, log_manager,
1375*14675a02SAndroid Build Coastguard Worker       flags->opstats_last_successful_contribution_criteria());
1376*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<engine::ExampleIteratorFactory>
1377*14675a02SAndroid Build Coastguard Worker       env_eligibility_example_iterator_factory =
1378*14675a02SAndroid Build Coastguard Worker           CreateSimpleTaskEnvironmentIteratorFactory(
1379*14675a02SAndroid Build Coastguard Worker               env_deps, eligibility_selector_context);
1380*14675a02SAndroid Build Coastguard Worker   std::vector<engine::ExampleIteratorFactory*>
1381*14675a02SAndroid Build Coastguard Worker       eligibility_example_iterator_factories{
1382*14675a02SAndroid Build Coastguard Worker           &opstats_example_iterator_factory,
1383*14675a02SAndroid Build Coastguard Worker           env_eligibility_example_iterator_factory.get()};
1384*14675a02SAndroid Build Coastguard Worker   // Note that this method will update fl_runner_result's fields with values
1385*14675a02SAndroid Build Coastguard Worker   // received over the course of the eligibility eval protocol interaction.
1386*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<EligibilityEvalResult> eligibility_eval_result =
1387*14675a02SAndroid Build Coastguard Worker       IssueEligibilityEvalCheckinAndRunPlan(
1388*14675a02SAndroid Build Coastguard Worker           eligibility_example_iterator_factories, should_abort, phase_logger,
1389*14675a02SAndroid Build Coastguard Worker           files, log_manager, opstats_logger, flags, federated_protocol,
1390*14675a02SAndroid Build Coastguard Worker           timing_config, reference_time, fl_runner_result);
1391*14675a02SAndroid Build Coastguard Worker   if (!eligibility_eval_result.ok()) {
1392*14675a02SAndroid Build Coastguard Worker     return fl_runner_result;
1393*14675a02SAndroid Build Coastguard Worker   }
1394*14675a02SAndroid Build Coastguard Worker   auto checkin_result =
1395*14675a02SAndroid Build Coastguard Worker       IssueCheckin(phase_logger, log_manager, files, federated_protocol,
1396*14675a02SAndroid Build Coastguard Worker                    std::move(eligibility_eval_result->task_eligibility_info),
1397*14675a02SAndroid Build Coastguard Worker                    reference_time, population_name, fl_runner_result, flags);
1398*14675a02SAndroid Build Coastguard Worker   if (!checkin_result.ok()) {
1399*14675a02SAndroid Build Coastguard Worker     return fl_runner_result;
1400*14675a02SAndroid Build Coastguard Worker   }
1401*14675a02SAndroid Build Coastguard Worker   SelectorContext federated_selector_context_with_task_name =
1402*14675a02SAndroid Build Coastguard Worker       federated_selector_context;
1403*14675a02SAndroid Build Coastguard Worker   federated_selector_context_with_task_name.mutable_computation_properties()
1404*14675a02SAndroid Build Coastguard Worker       ->mutable_federated()
1405*14675a02SAndroid Build Coastguard Worker       ->set_task_name(checkin_result->task_name);
1406*14675a02SAndroid Build Coastguard Worker   if (flags->enable_computation_id()) {
1407*14675a02SAndroid Build Coastguard Worker     federated_selector_context_with_task_name.mutable_computation_properties()
1408*14675a02SAndroid Build Coastguard Worker         ->mutable_federated()
1409*14675a02SAndroid Build Coastguard Worker         ->set_computation_id(checkin_result->computation_id);
1410*14675a02SAndroid Build Coastguard Worker   }
1411*14675a02SAndroid Build Coastguard Worker   if (checkin_result->plan.phase().has_example_query_spec()) {
1412*14675a02SAndroid Build Coastguard Worker     federated_selector_context_with_task_name.mutable_computation_properties()
1413*14675a02SAndroid Build Coastguard Worker         ->set_example_iterator_output_format(
1414*14675a02SAndroid Build Coastguard Worker             ::fcp::client::QueryTimeComputationProperties::
1415*14675a02SAndroid Build Coastguard Worker                 EXAMPLE_QUERY_RESULT);
1416*14675a02SAndroid Build Coastguard Worker   }
1417*14675a02SAndroid Build Coastguard Worker   // Include the last successful contribution timestamp in the
1418*14675a02SAndroid Build Coastguard Worker   // SelectorContext.
1419*14675a02SAndroid Build Coastguard Worker   const auto& opstats_db = opstats_logger->GetOpStatsDb();
1420*14675a02SAndroid Build Coastguard Worker   if (opstats_db != nullptr) {
1421*14675a02SAndroid Build Coastguard Worker     absl::StatusOr<opstats::OpStatsSequence> data = opstats_db->Read();
1422*14675a02SAndroid Build Coastguard Worker     if (data.ok()) {
1423*14675a02SAndroid Build Coastguard Worker       std::optional<google::protobuf::Timestamp>
1424*14675a02SAndroid Build Coastguard Worker           last_successful_contribution_time =
1425*14675a02SAndroid Build Coastguard Worker               opstats::GetLastSuccessfulContributionTime(
1426*14675a02SAndroid Build Coastguard Worker                   *data, checkin_result->task_name);
1427*14675a02SAndroid Build Coastguard Worker       if (last_successful_contribution_time.has_value()) {
1428*14675a02SAndroid Build Coastguard Worker         *(federated_selector_context_with_task_name
1429*14675a02SAndroid Build Coastguard Worker               .mutable_computation_properties()
1430*14675a02SAndroid Build Coastguard Worker               ->mutable_federated()
1431*14675a02SAndroid Build Coastguard Worker               ->mutable_historical_context()
1432*14675a02SAndroid Build Coastguard Worker               ->mutable_last_successful_contribution_time()) =
1433*14675a02SAndroid Build Coastguard Worker             *last_successful_contribution_time;
1434*14675a02SAndroid Build Coastguard Worker       }
1435*14675a02SAndroid Build Coastguard Worker     }
1436*14675a02SAndroid Build Coastguard Worker   }
1437*14675a02SAndroid Build Coastguard Worker   if (checkin_result->plan.phase().has_example_query_spec()) {
1438*14675a02SAndroid Build Coastguard Worker     // Example query plan only supports simple agg for now.
1439*14675a02SAndroid Build Coastguard Worker     *(federated_selector_context_with_task_name
1440*14675a02SAndroid Build Coastguard Worker           .mutable_computation_properties()
1441*14675a02SAndroid Build Coastguard Worker           ->mutable_federated()
1442*14675a02SAndroid Build Coastguard Worker           ->mutable_simple_aggregation()) = SimpleAggregation();
1443*14675a02SAndroid Build Coastguard Worker   } else {
1444*14675a02SAndroid Build Coastguard Worker     const auto& federated_compute_io_router =
1445*14675a02SAndroid Build Coastguard Worker         checkin_result->plan.phase().federated_compute();
1446*14675a02SAndroid Build Coastguard Worker     const bool has_simpleagg_tensors =
1447*14675a02SAndroid Build Coastguard Worker         !federated_compute_io_router.output_filepath_tensor_name().empty();
1448*14675a02SAndroid Build Coastguard Worker     bool all_aggregations_are_secagg = true;
1449*14675a02SAndroid Build Coastguard Worker     for (const auto& aggregation : federated_compute_io_router.aggregations()) {
1450*14675a02SAndroid Build Coastguard Worker       all_aggregations_are_secagg &=
1451*14675a02SAndroid Build Coastguard Worker           aggregation.second.protocol_config_case() ==
1452*14675a02SAndroid Build Coastguard Worker           AggregationConfig::kSecureAggregation;
1453*14675a02SAndroid Build Coastguard Worker     }
1454*14675a02SAndroid Build Coastguard Worker     if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
1455*14675a02SAndroid Build Coastguard Worker       federated_selector_context_with_task_name
1456*14675a02SAndroid Build Coastguard Worker           .mutable_computation_properties()
1457*14675a02SAndroid Build Coastguard Worker           ->mutable_federated()
1458*14675a02SAndroid Build Coastguard Worker           ->mutable_secure_aggregation()
1459*14675a02SAndroid Build Coastguard Worker           ->set_minimum_clients_in_server_visible_aggregate(
1460*14675a02SAndroid Build Coastguard Worker               checkin_result->minimum_clients_in_server_visible_aggregate);
1461*14675a02SAndroid Build Coastguard Worker     } else {
1462*14675a02SAndroid Build Coastguard Worker       // Has an output checkpoint, so some tensors must be simply aggregated.
1463*14675a02SAndroid Build Coastguard Worker       *(federated_selector_context_with_task_name
1464*14675a02SAndroid Build Coastguard Worker             .mutable_computation_properties()
1465*14675a02SAndroid Build Coastguard Worker             ->mutable_federated()
1466*14675a02SAndroid Build Coastguard Worker             ->mutable_simple_aggregation()) = SimpleAggregation();
1467*14675a02SAndroid Build Coastguard Worker     }
1468*14675a02SAndroid Build Coastguard Worker   }
1469*14675a02SAndroid Build Coastguard Worker   RetryWindow report_retry_window;
1470*14675a02SAndroid Build Coastguard Worker   phase_logger.LogComputationStarted();
1471*14675a02SAndroid Build Coastguard Worker   absl::Time run_plan_start_time = absl::Now();
1472*14675a02SAndroid Build Coastguard Worker   NetworkStats run_plan_start_network_stats =
1473*14675a02SAndroid Build Coastguard Worker       GetCumulativeNetworkStats(federated_protocol, fedselect_manager);
1474*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::string> checkpoint_output_filename =
1475*14675a02SAndroid Build Coastguard Worker       files->CreateTempFile("output", ".ckp");
1476*14675a02SAndroid Build Coastguard Worker   if (!checkpoint_output_filename.ok()) {
1477*14675a02SAndroid Build Coastguard Worker     auto status = checkpoint_output_filename.status();
1478*14675a02SAndroid Build Coastguard Worker     auto message = absl::StrCat(
1479*14675a02SAndroid Build Coastguard Worker         "Could not create temporary output checkpoint file: code: ",
1480*14675a02SAndroid Build Coastguard Worker         status.code(), ", message: ", status.message());
1481*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationIOError(
1482*14675a02SAndroid Build Coastguard Worker         status, ExampleStats(),
1483*14675a02SAndroid Build Coastguard Worker         GetNetworkStatsSince(federated_protocol, fedselect_manager,
1484*14675a02SAndroid Build Coastguard Worker                              run_plan_start_network_stats),
1485*14675a02SAndroid Build Coastguard Worker         run_plan_start_time);
1486*14675a02SAndroid Build Coastguard Worker     return fl_runner_result;
1487*14675a02SAndroid Build Coastguard Worker   }
1488*14675a02SAndroid Build Coastguard Worker   // Regular plans can use example iterators from the SimpleTaskEnvironment,
1489*14675a02SAndroid Build Coastguard Worker   // those reading the OpStats DB, or those serving Federated Select slices.
1490*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1491*14675a02SAndroid Build Coastguard Worker       CreateSimpleTaskEnvironmentIteratorFactory(
1492*14675a02SAndroid Build Coastguard Worker           env_deps, federated_selector_context_with_task_name);
1493*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
1494*14675a02SAndroid Build Coastguard Worker       fedselect_example_iterator_factory =
1495*14675a02SAndroid Build Coastguard Worker           fedselect_manager->CreateExampleIteratorFactoryForUriTemplate(
1496*14675a02SAndroid Build Coastguard Worker               checkin_result->federated_select_uri_template);
1497*14675a02SAndroid Build Coastguard Worker   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1498*14675a02SAndroid Build Coastguard Worker       fedselect_example_iterator_factory.get(),
1499*14675a02SAndroid Build Coastguard Worker       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1500*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1501*14675a02SAndroid Build Coastguard Worker       checkin_result->plan.phase().has_example_query_spec()
1502*14675a02SAndroid Build Coastguard Worker           ? RunPlanWithExampleQuerySpec(
1503*14675a02SAndroid Build Coastguard Worker                 example_iterator_factories, opstats_logger, flags,
1504*14675a02SAndroid Build Coastguard Worker                 checkin_result->plan, *checkpoint_output_filename)
1505*14675a02SAndroid Build Coastguard Worker           : RunPlanWithTensorflowSpec(
1506*14675a02SAndroid Build Coastguard Worker                 example_iterator_factories, should_abort, log_manager,
1507*14675a02SAndroid Build Coastguard Worker                 opstats_logger, flags, checkin_result->plan,
1508*14675a02SAndroid Build Coastguard Worker                 checkin_result->checkpoint_input_filename,
1509*14675a02SAndroid Build Coastguard Worker                 *checkpoint_output_filename, timing_config);
1510*14675a02SAndroid Build Coastguard Worker   // Update the FLRunnerResult fields to account for any network usage during
1511*14675a02SAndroid Build Coastguard Worker   // the execution of the plan (e.g. due to Federated Select slices having
1512*14675a02SAndroid Build Coastguard Worker   // been fetched).
1513*14675a02SAndroid Build Coastguard Worker   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1514*14675a02SAndroid Build Coastguard Worker                                    phase_logger, fl_runner_result);
1515*14675a02SAndroid Build Coastguard Worker   auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
1516*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<ComputationResults> computation_results;
1517*14675a02SAndroid Build Coastguard Worker   if (outcome == engine::PlanOutcome::kSuccess) {
1518*14675a02SAndroid Build Coastguard Worker     computation_results = CreateComputationResults(
1519*14675a02SAndroid Build Coastguard Worker         checkin_result->plan.phase().has_example_query_spec()
1520*14675a02SAndroid Build Coastguard Worker             ? nullptr
1521*14675a02SAndroid Build Coastguard Worker             : &checkin_result->plan.phase().tensorflow_spec(),
1522*14675a02SAndroid Build Coastguard Worker         plan_result_and_checkpoint_file);
1523*14675a02SAndroid Build Coastguard Worker   }
1524*14675a02SAndroid Build Coastguard Worker   LogComputationOutcome(
1525*14675a02SAndroid Build Coastguard Worker       plan_result_and_checkpoint_file.plan_result, computation_results.status(),
1526*14675a02SAndroid Build Coastguard Worker       phase_logger,
1527*14675a02SAndroid Build Coastguard Worker       GetNetworkStatsSince(federated_protocol, fedselect_manager,
1528*14675a02SAndroid Build Coastguard Worker                            run_plan_start_network_stats),
1529*14675a02SAndroid Build Coastguard Worker       run_plan_start_time, reference_time);
1530*14675a02SAndroid Build Coastguard Worker   absl::Status report_result = ReportPlanResult(
1531*14675a02SAndroid Build Coastguard Worker       federated_protocol, phase_logger, std::move(computation_results),
1532*14675a02SAndroid Build Coastguard Worker       run_plan_start_time, reference_time);
1533*14675a02SAndroid Build Coastguard Worker   if (outcome == engine::PlanOutcome::kSuccess && report_result.ok()) {
1534*14675a02SAndroid Build Coastguard Worker     // Only if training succeeded *and* reporting succeeded do we consider
1535*14675a02SAndroid Build Coastguard Worker     // the device to have contributed successfully.
1536*14675a02SAndroid Build Coastguard Worker     fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
1537*14675a02SAndroid Build Coastguard Worker   }
1538*14675a02SAndroid Build Coastguard Worker   // Update the FLRunnerResult fields one more time to account for the
1539*14675a02SAndroid Build Coastguard Worker   // "Report" protocol interaction.
1540*14675a02SAndroid Build Coastguard Worker   UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1541*14675a02SAndroid Build Coastguard Worker                                    phase_logger, fl_runner_result);
1542*14675a02SAndroid Build Coastguard Worker   return fl_runner_result;
1543*14675a02SAndroid Build Coastguard Worker }
RunPlanWithTensorflowSpecForTesting(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)1544*14675a02SAndroid Build Coastguard Worker FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
1545*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1546*14675a02SAndroid Build Coastguard Worker     Files* files, LogManager* log_manager, const Flags* flags,
1547*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
1548*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
1549*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1550*14675a02SAndroid Build Coastguard Worker     const absl::Time run_plan_start_time, const absl::Time reference_time) {
1551*14675a02SAndroid Build Coastguard Worker   FLRunnerTensorflowSpecResult result;
1552*14675a02SAndroid Build Coastguard Worker   result.set_outcome(engine::PhaseOutcome::ERROR);
1553*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result(engine::PlanOutcome::kTensorflowError,
1554*14675a02SAndroid Build Coastguard Worker                                  absl::UnknownError(""));
1555*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [env_deps, &timing_config]() {
1556*14675a02SAndroid Build Coastguard Worker     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1557*14675a02SAndroid Build Coastguard Worker   };
1558*14675a02SAndroid Build Coastguard Worker   auto opstats_logger =
1559*14675a02SAndroid Build Coastguard Worker       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1560*14675a02SAndroid Build Coastguard Worker                                   /*session_name=*/"", /*population_name=*/"");
1561*14675a02SAndroid Build Coastguard Worker   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1562*14675a02SAndroid Build Coastguard Worker                                log_manager, flags);
1563*14675a02SAndroid Build Coastguard Worker   // Regular plans can use example iterators from the SimpleTaskEnvironment,
1564*14675a02SAndroid Build Coastguard Worker   // those reading the OpStats DB, or those serving Federated Select slices.
1565*14675a02SAndroid Build Coastguard Worker   // However, we don't provide a Federated Select-specific example iterator
1566*14675a02SAndroid Build Coastguard Worker   // factory. That way, the Federated Select slice queries will be forwarded
1567*14675a02SAndroid Build Coastguard Worker   // to SimpleTaskEnvironment, which can handle them by providing
1568*14675a02SAndroid Build Coastguard Worker   // test-specific slices if they want to.
1569*14675a02SAndroid Build Coastguard Worker   //
1570*14675a02SAndroid Build Coastguard Worker   // Eligibility eval plans can only use iterators from the
1571*14675a02SAndroid Build Coastguard Worker   // SimpleTaskEnvironment and those reading the OpStats DB.
1572*14675a02SAndroid Build Coastguard Worker   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1573*14675a02SAndroid Build Coastguard Worker       opstats_logger.get(), log_manager,
1574*14675a02SAndroid Build Coastguard Worker       flags->opstats_last_successful_contribution_criteria());
1575*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1576*14675a02SAndroid Build Coastguard Worker       CreateSimpleTaskEnvironmentIteratorFactory(env_deps, SelectorContext());
1577*14675a02SAndroid Build Coastguard Worker   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1578*14675a02SAndroid Build Coastguard Worker       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1579*14675a02SAndroid Build Coastguard Worker   phase_logger.LogComputationStarted();
1580*14675a02SAndroid Build Coastguard Worker   if (client_plan.phase().has_federated_compute()) {
1581*14675a02SAndroid Build Coastguard Worker     absl::StatusOr<std::string> checkpoint_output_filename =
1582*14675a02SAndroid Build Coastguard Worker         files->CreateTempFile("output", ".ckp");
1583*14675a02SAndroid Build Coastguard Worker     if (!checkpoint_output_filename.ok()) {
1584*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationIOError(
1585*14675a02SAndroid Build Coastguard Worker           checkpoint_output_filename.status(), ExampleStats(),
1586*14675a02SAndroid Build Coastguard Worker           // Empty network stats, since no network protocol is actually used
1587*14675a02SAndroid Build Coastguard Worker           // in this method.
1588*14675a02SAndroid Build Coastguard Worker           NetworkStats(), run_plan_start_time);
1589*14675a02SAndroid Build Coastguard Worker       return result;
1590*14675a02SAndroid Build Coastguard Worker     }
1591*14675a02SAndroid Build Coastguard Worker     // Regular TensorflowSpec-based plans.
1592*14675a02SAndroid Build Coastguard Worker     PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1593*14675a02SAndroid Build Coastguard Worker         RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
1594*14675a02SAndroid Build Coastguard Worker                                   log_manager, opstats_logger.get(), flags,
1595*14675a02SAndroid Build Coastguard Worker                                   client_plan, checkpoint_input_filename,
1596*14675a02SAndroid Build Coastguard Worker                                   *checkpoint_output_filename, timing_config);
1597*14675a02SAndroid Build Coastguard Worker     result.set_checkpoint_output_filename(
1598*14675a02SAndroid Build Coastguard Worker         plan_result_and_checkpoint_file.checkpoint_file);
1599*14675a02SAndroid Build Coastguard Worker     plan_result = std::move(plan_result_and_checkpoint_file.plan_result);
1600*14675a02SAndroid Build Coastguard Worker   } else if (client_plan.phase().has_federated_compute_eligibility()) {
1601*14675a02SAndroid Build Coastguard Worker     // Eligibility eval plans.
1602*14675a02SAndroid Build Coastguard Worker     plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
1603*14675a02SAndroid Build Coastguard Worker         example_iterator_factories, should_abort, log_manager,
1604*14675a02SAndroid Build Coastguard Worker         opstats_logger.get(), flags, client_plan, checkpoint_input_filename,
1605*14675a02SAndroid Build Coastguard Worker         timing_config, run_plan_start_time, reference_time);
1606*14675a02SAndroid Build Coastguard Worker   } else {
1607*14675a02SAndroid Build Coastguard Worker     // This branch shouldn't be taken, unless we add an additional type of
1608*14675a02SAndroid Build Coastguard Worker     // TensorflowSpec-based plan in the future. We return a readable error so
1609*14675a02SAndroid Build Coastguard Worker     // that when such new plan types *are* added, they result in clear
1610*14675a02SAndroid Build Coastguard Worker     // compatibility test failures when such plans are erroneously targeted at
1611*14675a02SAndroid Build Coastguard Worker     // old releases that don't support them yet.
1612*14675a02SAndroid Build Coastguard Worker     event_publisher->PublishIoError("Unsupported TensorflowSpec-based plan");
1613*14675a02SAndroid Build Coastguard Worker     return result;
1614*14675a02SAndroid Build Coastguard Worker   }
1615*14675a02SAndroid Build Coastguard Worker   // Copy output tensors into the result proto.
1616*14675a02SAndroid Build Coastguard Worker   result.set_outcome(
1617*14675a02SAndroid Build Coastguard Worker       engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome));
1618*14675a02SAndroid Build Coastguard Worker   if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
1619*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < plan_result.output_names.size(); i++) {
1620*14675a02SAndroid Build Coastguard Worker       tensorflow::TensorProto output_tensor_proto;
1621*14675a02SAndroid Build Coastguard Worker       plan_result.output_tensors[i].AsProtoField(&output_tensor_proto);
1622*14675a02SAndroid Build Coastguard Worker       (*result.mutable_output_tensors())[plan_result.output_names[i]] =
1623*14675a02SAndroid Build Coastguard Worker           std::move(output_tensor_proto);
1624*14675a02SAndroid Build Coastguard Worker     }
1625*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationCompleted(
1626*14675a02SAndroid Build Coastguard Worker         plan_result.example_stats,
1627*14675a02SAndroid Build Coastguard Worker         // Empty network stats, since no network protocol is actually used in
1628*14675a02SAndroid Build Coastguard Worker         // this method.
1629*14675a02SAndroid Build Coastguard Worker         NetworkStats(), run_plan_start_time, reference_time);
1630*14675a02SAndroid Build Coastguard Worker   } else {
1631*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationTensorflowError(
1632*14675a02SAndroid Build Coastguard Worker         plan_result.original_status, plan_result.example_stats, NetworkStats(),
1633*14675a02SAndroid Build Coastguard Worker         run_plan_start_time, reference_time);
1634*14675a02SAndroid Build Coastguard Worker   }
1635*14675a02SAndroid Build Coastguard Worker   return result;
1636*14675a02SAndroid Build Coastguard Worker }
1637*14675a02SAndroid Build Coastguard Worker }  // namespace client
1638*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
1639