xref: /aosp_15_r20/external/federated-compute/fcp/client/fcp_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2023 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fcp_runner.h"
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_iterator_factory.h"
20*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_query_plan_engine.h"
21*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
22*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_plan_engine.h"
23*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_logger.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
26*14675a02SAndroid Build Coastguard Worker 
27*14675a02SAndroid Build Coastguard Worker namespace fcp {
28*14675a02SAndroid Build Coastguard Worker namespace client {
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker using ::fcp::client::opstats::OpStatsLogger;
31*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::AggregationConfig;
32*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::ClientOnlyPlan;
33*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::FederatedComputeIORouter;
34*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::TensorflowSpec;
35*14675a02SAndroid Build Coastguard Worker 
36*14675a02SAndroid Build Coastguard Worker using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
37*14675a02SAndroid Build Coastguard Worker namespace {
38*14675a02SAndroid Build Coastguard Worker 
39*14675a02SAndroid Build Coastguard Worker // Creates an ExampleIteratorFactory that routes queries to the
40*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment::CreateExampleIterator() method.
41*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)42*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(
43*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
44*14675a02SAndroid Build Coastguard Worker   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
45*14675a02SAndroid Build Coastguard Worker       /*can_handle_func=*/
46*14675a02SAndroid Build Coastguard Worker       [](const google::internal::federated::plan::ExampleSelector&) {
47*14675a02SAndroid Build Coastguard Worker         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
48*14675a02SAndroid Build Coastguard Worker         // be the catch-all factory that is able to handle all queries
49*14675a02SAndroid Build Coastguard Worker         // that no other ExampleIteratorFactory is able to handle.
50*14675a02SAndroid Build Coastguard Worker         return true;
51*14675a02SAndroid Build Coastguard Worker       },
52*14675a02SAndroid Build Coastguard Worker       /*create_iterator_func=*/
53*14675a02SAndroid Build Coastguard Worker       [task_env, selector_context](
54*14675a02SAndroid Build Coastguard Worker           const google::internal::federated::plan::ExampleSelector&
55*14675a02SAndroid Build Coastguard Worker               example_selector) {
56*14675a02SAndroid Build Coastguard Worker         return task_env->CreateExampleIterator(example_selector,
57*14675a02SAndroid Build Coastguard Worker                                                selector_context);
58*14675a02SAndroid Build Coastguard Worker       },
59*14675a02SAndroid Build Coastguard Worker       /*should_collect_stats=*/true);
60*14675a02SAndroid Build Coastguard Worker }
61*14675a02SAndroid Build Coastguard Worker 
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)62*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
63*14675a02SAndroid Build Coastguard Worker     const FederatedComputeIORouter& io_router,
64*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
65*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename) {
66*14675a02SAndroid Build Coastguard Worker   auto inputs = std::make_unique<TfLiteInputs>();
67*14675a02SAndroid Build Coastguard Worker   if (!io_router.input_filepath_tensor_name().empty()) {
68*14675a02SAndroid Build Coastguard Worker     (*inputs)[io_router.input_filepath_tensor_name()] =
69*14675a02SAndroid Build Coastguard Worker         checkpoint_input_filename;
70*14675a02SAndroid Build Coastguard Worker   }
71*14675a02SAndroid Build Coastguard Worker 
72*14675a02SAndroid Build Coastguard Worker   if (!io_router.output_filepath_tensor_name().empty()) {
73*14675a02SAndroid Build Coastguard Worker     (*inputs)[io_router.output_filepath_tensor_name()] =
74*14675a02SAndroid Build Coastguard Worker         checkpoint_output_filename;
75*14675a02SAndroid Build Coastguard Worker   }
76*14675a02SAndroid Build Coastguard Worker 
77*14675a02SAndroid Build Coastguard Worker   return inputs;
78*14675a02SAndroid Build Coastguard Worker }
79*14675a02SAndroid Build Coastguard Worker 
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)80*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
81*14675a02SAndroid Build Coastguard Worker     const TensorflowSpec& tensorflow_spec,
82*14675a02SAndroid Build Coastguard Worker     const FederatedComputeIORouter& io_router) {
83*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names;
84*14675a02SAndroid Build Coastguard Worker   // The order of output tensor names should match the order in TensorflowSpec.
85*14675a02SAndroid Build Coastguard Worker   for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
86*14675a02SAndroid Build Coastguard Worker     std::string tensor_name = output_tensor_spec.name();
87*14675a02SAndroid Build Coastguard Worker     if (!io_router.aggregations().contains(tensor_name) ||
88*14675a02SAndroid Build Coastguard Worker         !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
89*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
90*14675a02SAndroid Build Coastguard Worker           "Output tensor is missing in AggregationConfig, or has unsupported "
91*14675a02SAndroid Build Coastguard Worker           "aggregation type.");
92*14675a02SAndroid Build Coastguard Worker     }
93*14675a02SAndroid Build Coastguard Worker     output_names.push_back(tensor_name);
94*14675a02SAndroid Build Coastguard Worker   }
95*14675a02SAndroid Build Coastguard Worker 
96*14675a02SAndroid Build Coastguard Worker   return output_names;
97*14675a02SAndroid Build Coastguard Worker }
98*14675a02SAndroid Build Coastguard Worker 
99*14675a02SAndroid Build Coastguard Worker struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anonb8072fe00111::PlanResultAndCheckpointFile100*14675a02SAndroid Build Coastguard Worker   explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
101*14675a02SAndroid Build Coastguard Worker       : plan_result(std::move(plan_result)) {}
102*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result;
103*14675a02SAndroid Build Coastguard Worker   std::string checkpoint_file;
104*14675a02SAndroid Build Coastguard Worker 
105*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
106*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
107*14675a02SAndroid Build Coastguard Worker       default;
108*14675a02SAndroid Build Coastguard Worker 
109*14675a02SAndroid Build Coastguard Worker   // Disallow copy and assign.
110*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
111*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
112*14675a02SAndroid Build Coastguard Worker       delete;
113*14675a02SAndroid Build Coastguard Worker };
114*14675a02SAndroid Build Coastguard Worker 
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)115*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
116*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
117*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
118*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
119*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename) {
120*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_example_query_spec()) {
121*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
122*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
123*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
124*14675a02SAndroid Build Coastguard Worker   }
125*14675a02SAndroid Build Coastguard Worker   if (!flags->enable_example_query_plan_engine()) {
126*14675a02SAndroid Build Coastguard Worker     // Example query plan received while the flag is off.
127*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
128*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
129*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError(
130*14675a02SAndroid Build Coastguard Worker             "Example query plan received while the flag is off")));
131*14675a02SAndroid Build Coastguard Worker   }
132*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_federated_example_query()) {
133*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
134*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
135*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
136*14675a02SAndroid Build Coastguard Worker   }
137*14675a02SAndroid Build Coastguard Worker   for (const auto& example_query :
138*14675a02SAndroid Build Coastguard Worker        client_plan.phase().example_query_spec().example_queries()) {
139*14675a02SAndroid Build Coastguard Worker     for (auto const& [vector_name, spec] :
140*14675a02SAndroid Build Coastguard Worker          example_query.output_vector_specs()) {
141*14675a02SAndroid Build Coastguard Worker       const auto& aggregations =
142*14675a02SAndroid Build Coastguard Worker           client_plan.phase().federated_example_query().aggregations();
143*14675a02SAndroid Build Coastguard Worker       if ((aggregations.find(vector_name) == aggregations.end()) ||
144*14675a02SAndroid Build Coastguard Worker           !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
145*14675a02SAndroid Build Coastguard Worker         return PlanResultAndCheckpointFile(engine::PlanResult(
146*14675a02SAndroid Build Coastguard Worker             engine::PlanOutcome::kInvalidArgument,
147*14675a02SAndroid Build Coastguard Worker             absl::InvalidArgumentError("Output vector is missing in "
148*14675a02SAndroid Build Coastguard Worker                                        "AggregationConfig, or has unsupported "
149*14675a02SAndroid Build Coastguard Worker                                        "aggregation type.")));
150*14675a02SAndroid Build Coastguard Worker       }
151*14675a02SAndroid Build Coastguard Worker     }
152*14675a02SAndroid Build Coastguard Worker   }
153*14675a02SAndroid Build Coastguard Worker 
154*14675a02SAndroid Build Coastguard Worker   engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
155*14675a02SAndroid Build Coastguard Worker                                              opstats_logger);
156*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result = plan_engine.RunPlan(
157*14675a02SAndroid Build Coastguard Worker       client_plan.phase().example_query_spec(), checkpoint_output_filename);
158*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile result(std::move(plan_result));
159*14675a02SAndroid Build Coastguard Worker   result.checkpoint_file = checkpoint_output_filename;
160*14675a02SAndroid Build Coastguard Worker   return result;
161*14675a02SAndroid Build Coastguard Worker }
162*14675a02SAndroid Build Coastguard Worker 
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)163*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
164*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
165*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, LogManager* log_manager,
166*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
167*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan,
168*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
169*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename,
170*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
171*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_tensorflow_spec()) {
172*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
173*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
174*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
175*14675a02SAndroid Build Coastguard Worker   }
176*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_federated_compute()) {
177*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
178*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument,
179*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
180*14675a02SAndroid Build Coastguard Worker   }
181*14675a02SAndroid Build Coastguard Worker 
182*14675a02SAndroid Build Coastguard Worker   // Get the output tensor names.
183*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::vector<std::string>> output_names;
184*14675a02SAndroid Build Coastguard Worker   output_names = ConstructOutputsWithDeterministicOrder(
185*14675a02SAndroid Build Coastguard Worker       client_plan.phase().tensorflow_spec(),
186*14675a02SAndroid Build Coastguard Worker       client_plan.phase().federated_compute());
187*14675a02SAndroid Build Coastguard Worker   if (!output_names.ok()) {
188*14675a02SAndroid Build Coastguard Worker     return PlanResultAndCheckpointFile(engine::PlanResult(
189*14675a02SAndroid Build Coastguard Worker         engine::PlanOutcome::kInvalidArgument, output_names.status()));
190*14675a02SAndroid Build Coastguard Worker   }
191*14675a02SAndroid Build Coastguard Worker 
192*14675a02SAndroid Build Coastguard Worker   // Run plan and get a set of output tensors back.
193*14675a02SAndroid Build Coastguard Worker   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
194*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteInputs> tflite_inputs =
195*14675a02SAndroid Build Coastguard Worker         ConstructTFLiteInputsForTensorflowSpecPlan(
196*14675a02SAndroid Build Coastguard Worker             client_plan.phase().federated_compute(), checkpoint_input_filename,
197*14675a02SAndroid Build Coastguard Worker             checkpoint_output_filename);
198*14675a02SAndroid Build Coastguard Worker     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
199*14675a02SAndroid Build Coastguard Worker                                          should_abort, log_manager,
200*14675a02SAndroid Build Coastguard Worker                                          opstats_logger, flags, &timing_config);
201*14675a02SAndroid Build Coastguard Worker     engine::PlanResult plan_result = plan_engine.RunPlan(
202*14675a02SAndroid Build Coastguard Worker         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
203*14675a02SAndroid Build Coastguard Worker         std::move(tflite_inputs), *output_names);
204*14675a02SAndroid Build Coastguard Worker     PlanResultAndCheckpointFile result(std::move(plan_result));
205*14675a02SAndroid Build Coastguard Worker     result.checkpoint_file = checkpoint_output_filename;
206*14675a02SAndroid Build Coastguard Worker 
207*14675a02SAndroid Build Coastguard Worker     return result;
208*14675a02SAndroid Build Coastguard Worker   }
209*14675a02SAndroid Build Coastguard Worker 
210*14675a02SAndroid Build Coastguard Worker   return PlanResultAndCheckpointFile(
211*14675a02SAndroid Build Coastguard Worker       engine::PlanResult(engine::PlanOutcome::kTensorflowError,
212*14675a02SAndroid Build Coastguard Worker                          absl::InternalError("No plan engine enabled")));
213*14675a02SAndroid Build Coastguard Worker }
214*14675a02SAndroid Build Coastguard Worker }  // namespace
215*14675a02SAndroid Build Coastguard Worker 
RunFederatedComputation(SimpleTaskEnvironment * env_deps,LogManager * log_manager,const Flags * flags,const google::internal::federated::plan::ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const std::string & session_name,const std::string & population_name,const std::string & task_name,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)216*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FLRunnerResult> RunFederatedComputation(
217*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* env_deps, LogManager* log_manager,
218*14675a02SAndroid Build Coastguard Worker     const Flags* flags,
219*14675a02SAndroid Build Coastguard Worker     const google::internal::federated::plan::ClientOnlyPlan& client_plan,
220*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_input_filename,
221*14675a02SAndroid Build Coastguard Worker     const std::string& checkpoint_output_filename,
222*14675a02SAndroid Build Coastguard Worker     const std::string& session_name, const std::string& population_name,
223*14675a02SAndroid Build Coastguard Worker     const std::string& task_name,
224*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
225*14675a02SAndroid Build Coastguard Worker   SelectorContext federated_selector_context;
226*14675a02SAndroid Build Coastguard Worker   federated_selector_context.mutable_computation_properties()->set_session_name(
227*14675a02SAndroid Build Coastguard Worker       session_name);
228*14675a02SAndroid Build Coastguard Worker   FederatedComputation federated_computation;
229*14675a02SAndroid Build Coastguard Worker   federated_computation.set_population_name(population_name);
230*14675a02SAndroid Build Coastguard Worker   *federated_selector_context.mutable_computation_properties()
231*14675a02SAndroid Build Coastguard Worker        ->mutable_federated() = federated_computation;
232*14675a02SAndroid Build Coastguard Worker   federated_selector_context.mutable_computation_properties()
233*14675a02SAndroid Build Coastguard Worker       ->mutable_federated()
234*14675a02SAndroid Build Coastguard Worker       ->set_task_name(task_name);
235*14675a02SAndroid Build Coastguard Worker   if (client_plan.phase().has_example_query_spec()) {
236*14675a02SAndroid Build Coastguard Worker     federated_selector_context.mutable_computation_properties()
237*14675a02SAndroid Build Coastguard Worker         ->set_example_iterator_output_format(
238*14675a02SAndroid Build Coastguard Worker             ::fcp::client::QueryTimeComputationProperties::
239*14675a02SAndroid Build Coastguard Worker                 EXAMPLE_QUERY_RESULT);
240*14675a02SAndroid Build Coastguard Worker   } else {
241*14675a02SAndroid Build Coastguard Worker     const auto& federated_compute_io_router =
242*14675a02SAndroid Build Coastguard Worker         client_plan.phase().federated_compute();
243*14675a02SAndroid Build Coastguard Worker     const bool has_simpleagg_tensors =
244*14675a02SAndroid Build Coastguard Worker         !federated_compute_io_router.output_filepath_tensor_name().empty();
245*14675a02SAndroid Build Coastguard Worker     bool all_aggregations_are_secagg = true;
246*14675a02SAndroid Build Coastguard Worker     for (const auto& aggregation : federated_compute_io_router.aggregations()) {
247*14675a02SAndroid Build Coastguard Worker       all_aggregations_are_secagg &=
248*14675a02SAndroid Build Coastguard Worker           aggregation.second.protocol_config_case() ==
249*14675a02SAndroid Build Coastguard Worker           AggregationConfig::kSecureAggregation;
250*14675a02SAndroid Build Coastguard Worker     }
251*14675a02SAndroid Build Coastguard Worker     if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
252*14675a02SAndroid Build Coastguard Worker       federated_selector_context.mutable_computation_properties()
253*14675a02SAndroid Build Coastguard Worker           ->mutable_federated()
254*14675a02SAndroid Build Coastguard Worker           ->mutable_secure_aggregation()
255*14675a02SAndroid Build Coastguard Worker           ->set_minimum_clients_in_server_visible_aggregate(100);
256*14675a02SAndroid Build Coastguard Worker     } else {
257*14675a02SAndroid Build Coastguard Worker       // Has an output checkpoint, so some tensors must be simply aggregated.
258*14675a02SAndroid Build Coastguard Worker       *(federated_selector_context.mutable_computation_properties()
259*14675a02SAndroid Build Coastguard Worker             ->mutable_federated()
260*14675a02SAndroid Build Coastguard Worker             ->mutable_simple_aggregation()) = SimpleAggregation();
261*14675a02SAndroid Build Coastguard Worker     }
262*14675a02SAndroid Build Coastguard Worker   }
263*14675a02SAndroid Build Coastguard Worker 
264*14675a02SAndroid Build Coastguard Worker   auto opstats_logger =
265*14675a02SAndroid Build Coastguard Worker       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
266*14675a02SAndroid Build Coastguard Worker                                   session_name, population_name);
267*14675a02SAndroid Build Coastguard Worker 
268*14675a02SAndroid Build Coastguard Worker   // Check if the device conditions allow for checking in with the server
269*14675a02SAndroid Build Coastguard Worker   // and running a federated computation. If not, bail early with the
270*14675a02SAndroid Build Coastguard Worker   // transient error retry window.
271*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [env_deps, &timing_config]() {
272*14675a02SAndroid Build Coastguard Worker     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
273*14675a02SAndroid Build Coastguard Worker   };
274*14675a02SAndroid Build Coastguard Worker 
275*14675a02SAndroid Build Coastguard Worker   // Regular plans can use example iterators from the SimpleTaskEnvironment,
276*14675a02SAndroid Build Coastguard Worker   // those reading the OpStats DB, or those serving Federated Select slices.
277*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
278*14675a02SAndroid Build Coastguard Worker       CreateSimpleTaskEnvironmentIteratorFactory(env_deps,
279*14675a02SAndroid Build Coastguard Worker                                                  federated_selector_context);
280*14675a02SAndroid Build Coastguard Worker   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
281*14675a02SAndroid Build Coastguard Worker       env_example_iterator_factory.get()};
282*14675a02SAndroid Build Coastguard Worker   PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
283*14675a02SAndroid Build Coastguard Worker       client_plan.phase().has_example_query_spec()
284*14675a02SAndroid Build Coastguard Worker           ? RunPlanWithExampleQuerySpec(example_iterator_factories,
285*14675a02SAndroid Build Coastguard Worker                                         opstats_logger.get(), flags,
286*14675a02SAndroid Build Coastguard Worker                                         client_plan, checkpoint_output_filename)
287*14675a02SAndroid Build Coastguard Worker           : RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
288*14675a02SAndroid Build Coastguard Worker                                       log_manager, opstats_logger.get(), flags,
289*14675a02SAndroid Build Coastguard Worker                                       client_plan, checkpoint_input_filename,
290*14675a02SAndroid Build Coastguard Worker                                       checkpoint_output_filename,
291*14675a02SAndroid Build Coastguard Worker                                       timing_config);
292*14675a02SAndroid Build Coastguard Worker   auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
293*14675a02SAndroid Build Coastguard Worker   FLRunnerResult fl_runner_result;
294*14675a02SAndroid Build Coastguard Worker 
295*14675a02SAndroid Build Coastguard Worker   if (outcome == engine::PlanOutcome::kSuccess) {
296*14675a02SAndroid Build Coastguard Worker     fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
297*14675a02SAndroid Build Coastguard Worker   } else {
298*14675a02SAndroid Build Coastguard Worker     switch (outcome) {
299*14675a02SAndroid Build Coastguard Worker       case engine::PlanOutcome::kInvalidArgument:
300*14675a02SAndroid Build Coastguard Worker         fl_runner_result.set_error_status(FLRunnerResult::INVALID_ARGUMENT);
301*14675a02SAndroid Build Coastguard Worker         break;
302*14675a02SAndroid Build Coastguard Worker       case engine::PlanOutcome::kTensorflowError:
303*14675a02SAndroid Build Coastguard Worker         fl_runner_result.set_error_status(FLRunnerResult::TENSORFLOW_ERROR);
304*14675a02SAndroid Build Coastguard Worker         break;
305*14675a02SAndroid Build Coastguard Worker       case engine::PlanOutcome::kExampleIteratorError:
306*14675a02SAndroid Build Coastguard Worker         fl_runner_result.set_error_status(
307*14675a02SAndroid Build Coastguard Worker             FLRunnerResult::EXAMPLE_ITERATOR_ERROR);
308*14675a02SAndroid Build Coastguard Worker         break;
309*14675a02SAndroid Build Coastguard Worker       default:
310*14675a02SAndroid Build Coastguard Worker         break;
311*14675a02SAndroid Build Coastguard Worker     }
312*14675a02SAndroid Build Coastguard Worker     fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
313*14675a02SAndroid Build Coastguard Worker     std::string error_message = std::string{
314*14675a02SAndroid Build Coastguard Worker         plan_result_and_checkpoint_file.plan_result.original_status.message()};
315*14675a02SAndroid Build Coastguard Worker     fl_runner_result.set_error_message(error_message);
316*14675a02SAndroid Build Coastguard Worker   }
317*14675a02SAndroid Build Coastguard Worker 
318*14675a02SAndroid Build Coastguard Worker   FLRunnerResult::ExampleStats example_stats;
319*14675a02SAndroid Build Coastguard Worker   example_stats.set_example_count(
320*14675a02SAndroid Build Coastguard Worker       plan_result_and_checkpoint_file.plan_result.example_stats.example_count);
321*14675a02SAndroid Build Coastguard Worker   example_stats.set_example_size_bytes(
322*14675a02SAndroid Build Coastguard Worker       plan_result_and_checkpoint_file.plan_result.example_stats
323*14675a02SAndroid Build Coastguard Worker           .example_size_bytes);
324*14675a02SAndroid Build Coastguard Worker 
325*14675a02SAndroid Build Coastguard Worker   *fl_runner_result.mutable_example_stats() = example_stats;
326*14675a02SAndroid Build Coastguard Worker 
327*14675a02SAndroid Build Coastguard Worker   return fl_runner_result;
328*14675a02SAndroid Build Coastguard Worker }
329*14675a02SAndroid Build Coastguard Worker 
330*14675a02SAndroid Build Coastguard Worker }  // namespace client
331*14675a02SAndroid Build Coastguard Worker }  // namespace fcp