/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fcp_runner.h"
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_iterator_factory.h"
20*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_query_plan_engine.h"
21*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
22*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_plan_engine.h"
23*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_logger.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
26*14675a02SAndroid Build Coastguard Worker
27*14675a02SAndroid Build Coastguard Worker namespace fcp {
28*14675a02SAndroid Build Coastguard Worker namespace client {
29*14675a02SAndroid Build Coastguard Worker
30*14675a02SAndroid Build Coastguard Worker using ::fcp::client::opstats::OpStatsLogger;
31*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::AggregationConfig;
32*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::ClientOnlyPlan;
33*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::FederatedComputeIORouter;
34*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::TensorflowSpec;
35*14675a02SAndroid Build Coastguard Worker
36*14675a02SAndroid Build Coastguard Worker using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
37*14675a02SAndroid Build Coastguard Worker namespace {
38*14675a02SAndroid Build Coastguard Worker
39*14675a02SAndroid Build Coastguard Worker // Creates an ExampleIteratorFactory that routes queries to the
40*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment::CreateExampleIterator() method.
41*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)42*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(
43*14675a02SAndroid Build Coastguard Worker SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
44*14675a02SAndroid Build Coastguard Worker return std::make_unique<engine::FunctionalExampleIteratorFactory>(
45*14675a02SAndroid Build Coastguard Worker /*can_handle_func=*/
46*14675a02SAndroid Build Coastguard Worker [](const google::internal::federated::plan::ExampleSelector&) {
47*14675a02SAndroid Build Coastguard Worker // The SimpleTaskEnvironment-based ExampleIteratorFactory should
48*14675a02SAndroid Build Coastguard Worker // be the catch-all factory that is able to handle all queries
49*14675a02SAndroid Build Coastguard Worker // that no other ExampleIteratorFactory is able to handle.
50*14675a02SAndroid Build Coastguard Worker return true;
51*14675a02SAndroid Build Coastguard Worker },
52*14675a02SAndroid Build Coastguard Worker /*create_iterator_func=*/
53*14675a02SAndroid Build Coastguard Worker [task_env, selector_context](
54*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::ExampleSelector&
55*14675a02SAndroid Build Coastguard Worker example_selector) {
56*14675a02SAndroid Build Coastguard Worker return task_env->CreateExampleIterator(example_selector,
57*14675a02SAndroid Build Coastguard Worker selector_context);
58*14675a02SAndroid Build Coastguard Worker },
59*14675a02SAndroid Build Coastguard Worker /*should_collect_stats=*/true);
60*14675a02SAndroid Build Coastguard Worker }
61*14675a02SAndroid Build Coastguard Worker
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)62*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
63*14675a02SAndroid Build Coastguard Worker const FederatedComputeIORouter& io_router,
64*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_input_filename,
65*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_output_filename) {
66*14675a02SAndroid Build Coastguard Worker auto inputs = std::make_unique<TfLiteInputs>();
67*14675a02SAndroid Build Coastguard Worker if (!io_router.input_filepath_tensor_name().empty()) {
68*14675a02SAndroid Build Coastguard Worker (*inputs)[io_router.input_filepath_tensor_name()] =
69*14675a02SAndroid Build Coastguard Worker checkpoint_input_filename;
70*14675a02SAndroid Build Coastguard Worker }
71*14675a02SAndroid Build Coastguard Worker
72*14675a02SAndroid Build Coastguard Worker if (!io_router.output_filepath_tensor_name().empty()) {
73*14675a02SAndroid Build Coastguard Worker (*inputs)[io_router.output_filepath_tensor_name()] =
74*14675a02SAndroid Build Coastguard Worker checkpoint_output_filename;
75*14675a02SAndroid Build Coastguard Worker }
76*14675a02SAndroid Build Coastguard Worker
77*14675a02SAndroid Build Coastguard Worker return inputs;
78*14675a02SAndroid Build Coastguard Worker }
79*14675a02SAndroid Build Coastguard Worker
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)80*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
81*14675a02SAndroid Build Coastguard Worker const TensorflowSpec& tensorflow_spec,
82*14675a02SAndroid Build Coastguard Worker const FederatedComputeIORouter& io_router) {
83*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names;
84*14675a02SAndroid Build Coastguard Worker // The order of output tensor names should match the order in TensorflowSpec.
85*14675a02SAndroid Build Coastguard Worker for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
86*14675a02SAndroid Build Coastguard Worker std::string tensor_name = output_tensor_spec.name();
87*14675a02SAndroid Build Coastguard Worker if (!io_router.aggregations().contains(tensor_name) ||
88*14675a02SAndroid Build Coastguard Worker !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
89*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
90*14675a02SAndroid Build Coastguard Worker "Output tensor is missing in AggregationConfig, or has unsupported "
91*14675a02SAndroid Build Coastguard Worker "aggregation type.");
92*14675a02SAndroid Build Coastguard Worker }
93*14675a02SAndroid Build Coastguard Worker output_names.push_back(tensor_name);
94*14675a02SAndroid Build Coastguard Worker }
95*14675a02SAndroid Build Coastguard Worker
96*14675a02SAndroid Build Coastguard Worker return output_names;
97*14675a02SAndroid Build Coastguard Worker }
98*14675a02SAndroid Build Coastguard Worker
99*14675a02SAndroid Build Coastguard Worker struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anonb8072fe00111::PlanResultAndCheckpointFile100*14675a02SAndroid Build Coastguard Worker explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
101*14675a02SAndroid Build Coastguard Worker : plan_result(std::move(plan_result)) {}
102*14675a02SAndroid Build Coastguard Worker engine::PlanResult plan_result;
103*14675a02SAndroid Build Coastguard Worker std::string checkpoint_file;
104*14675a02SAndroid Build Coastguard Worker
105*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
106*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
107*14675a02SAndroid Build Coastguard Worker default;
108*14675a02SAndroid Build Coastguard Worker
109*14675a02SAndroid Build Coastguard Worker // Disallow copy and assign.
110*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
111*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
112*14675a02SAndroid Build Coastguard Worker delete;
113*14675a02SAndroid Build Coastguard Worker };
114*14675a02SAndroid Build Coastguard Worker
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)115*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
116*14675a02SAndroid Build Coastguard Worker std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
117*14675a02SAndroid Build Coastguard Worker OpStatsLogger* opstats_logger, const Flags* flags,
118*14675a02SAndroid Build Coastguard Worker const ClientOnlyPlan& client_plan,
119*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_output_filename) {
120*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_example_query_spec()) {
121*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
122*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
123*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
124*14675a02SAndroid Build Coastguard Worker }
125*14675a02SAndroid Build Coastguard Worker if (!flags->enable_example_query_plan_engine()) {
126*14675a02SAndroid Build Coastguard Worker // Example query plan received while the flag is off.
127*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
128*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
129*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError(
130*14675a02SAndroid Build Coastguard Worker "Example query plan received while the flag is off")));
131*14675a02SAndroid Build Coastguard Worker }
132*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_federated_example_query()) {
133*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
134*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
135*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
136*14675a02SAndroid Build Coastguard Worker }
137*14675a02SAndroid Build Coastguard Worker for (const auto& example_query :
138*14675a02SAndroid Build Coastguard Worker client_plan.phase().example_query_spec().example_queries()) {
139*14675a02SAndroid Build Coastguard Worker for (auto const& [vector_name, spec] :
140*14675a02SAndroid Build Coastguard Worker example_query.output_vector_specs()) {
141*14675a02SAndroid Build Coastguard Worker const auto& aggregations =
142*14675a02SAndroid Build Coastguard Worker client_plan.phase().federated_example_query().aggregations();
143*14675a02SAndroid Build Coastguard Worker if ((aggregations.find(vector_name) == aggregations.end()) ||
144*14675a02SAndroid Build Coastguard Worker !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
145*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
146*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
147*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Output vector is missing in "
148*14675a02SAndroid Build Coastguard Worker "AggregationConfig, or has unsupported "
149*14675a02SAndroid Build Coastguard Worker "aggregation type.")));
150*14675a02SAndroid Build Coastguard Worker }
151*14675a02SAndroid Build Coastguard Worker }
152*14675a02SAndroid Build Coastguard Worker }
153*14675a02SAndroid Build Coastguard Worker
154*14675a02SAndroid Build Coastguard Worker engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
155*14675a02SAndroid Build Coastguard Worker opstats_logger);
156*14675a02SAndroid Build Coastguard Worker engine::PlanResult plan_result = plan_engine.RunPlan(
157*14675a02SAndroid Build Coastguard Worker client_plan.phase().example_query_spec(), checkpoint_output_filename);
158*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile result(std::move(plan_result));
159*14675a02SAndroid Build Coastguard Worker result.checkpoint_file = checkpoint_output_filename;
160*14675a02SAndroid Build Coastguard Worker return result;
161*14675a02SAndroid Build Coastguard Worker }
162*14675a02SAndroid Build Coastguard Worker
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)163*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
164*14675a02SAndroid Build Coastguard Worker std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
165*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, LogManager* log_manager,
166*14675a02SAndroid Build Coastguard Worker OpStatsLogger* opstats_logger, const Flags* flags,
167*14675a02SAndroid Build Coastguard Worker const ClientOnlyPlan& client_plan,
168*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_input_filename,
169*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_output_filename,
170*14675a02SAndroid Build Coastguard Worker const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
171*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_tensorflow_spec()) {
172*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
173*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
174*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
175*14675a02SAndroid Build Coastguard Worker }
176*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_federated_compute()) {
177*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
178*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument,
179*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
180*14675a02SAndroid Build Coastguard Worker }
181*14675a02SAndroid Build Coastguard Worker
182*14675a02SAndroid Build Coastguard Worker // Get the output tensor names.
183*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::vector<std::string>> output_names;
184*14675a02SAndroid Build Coastguard Worker output_names = ConstructOutputsWithDeterministicOrder(
185*14675a02SAndroid Build Coastguard Worker client_plan.phase().tensorflow_spec(),
186*14675a02SAndroid Build Coastguard Worker client_plan.phase().federated_compute());
187*14675a02SAndroid Build Coastguard Worker if (!output_names.ok()) {
188*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(engine::PlanResult(
189*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome::kInvalidArgument, output_names.status()));
190*14675a02SAndroid Build Coastguard Worker }
191*14675a02SAndroid Build Coastguard Worker
192*14675a02SAndroid Build Coastguard Worker // Run plan and get a set of output tensors back.
193*14675a02SAndroid Build Coastguard Worker if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
194*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TfLiteInputs> tflite_inputs =
195*14675a02SAndroid Build Coastguard Worker ConstructTFLiteInputsForTensorflowSpecPlan(
196*14675a02SAndroid Build Coastguard Worker client_plan.phase().federated_compute(), checkpoint_input_filename,
197*14675a02SAndroid Build Coastguard Worker checkpoint_output_filename);
198*14675a02SAndroid Build Coastguard Worker engine::TfLitePlanEngine plan_engine(example_iterator_factories,
199*14675a02SAndroid Build Coastguard Worker should_abort, log_manager,
200*14675a02SAndroid Build Coastguard Worker opstats_logger, flags, &timing_config);
201*14675a02SAndroid Build Coastguard Worker engine::PlanResult plan_result = plan_engine.RunPlan(
202*14675a02SAndroid Build Coastguard Worker client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
203*14675a02SAndroid Build Coastguard Worker std::move(tflite_inputs), *output_names);
204*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile result(std::move(plan_result));
205*14675a02SAndroid Build Coastguard Worker result.checkpoint_file = checkpoint_output_filename;
206*14675a02SAndroid Build Coastguard Worker
207*14675a02SAndroid Build Coastguard Worker return result;
208*14675a02SAndroid Build Coastguard Worker }
209*14675a02SAndroid Build Coastguard Worker
210*14675a02SAndroid Build Coastguard Worker return PlanResultAndCheckpointFile(
211*14675a02SAndroid Build Coastguard Worker engine::PlanResult(engine::PlanOutcome::kTensorflowError,
212*14675a02SAndroid Build Coastguard Worker absl::InternalError("No plan engine enabled")));
213*14675a02SAndroid Build Coastguard Worker }
214*14675a02SAndroid Build Coastguard Worker } // namespace
215*14675a02SAndroid Build Coastguard Worker
RunFederatedComputation(SimpleTaskEnvironment * env_deps,LogManager * log_manager,const Flags * flags,const google::internal::federated::plan::ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const std::string & session_name,const std::string & population_name,const std::string & task_name,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)216*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FLRunnerResult> RunFederatedComputation(
217*14675a02SAndroid Build Coastguard Worker SimpleTaskEnvironment* env_deps, LogManager* log_manager,
218*14675a02SAndroid Build Coastguard Worker const Flags* flags,
219*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::ClientOnlyPlan& client_plan,
220*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_input_filename,
221*14675a02SAndroid Build Coastguard Worker const std::string& checkpoint_output_filename,
222*14675a02SAndroid Build Coastguard Worker const std::string& session_name, const std::string& population_name,
223*14675a02SAndroid Build Coastguard Worker const std::string& task_name,
224*14675a02SAndroid Build Coastguard Worker const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
225*14675a02SAndroid Build Coastguard Worker SelectorContext federated_selector_context;
226*14675a02SAndroid Build Coastguard Worker federated_selector_context.mutable_computation_properties()->set_session_name(
227*14675a02SAndroid Build Coastguard Worker session_name);
228*14675a02SAndroid Build Coastguard Worker FederatedComputation federated_computation;
229*14675a02SAndroid Build Coastguard Worker federated_computation.set_population_name(population_name);
230*14675a02SAndroid Build Coastguard Worker *federated_selector_context.mutable_computation_properties()
231*14675a02SAndroid Build Coastguard Worker ->mutable_federated() = federated_computation;
232*14675a02SAndroid Build Coastguard Worker federated_selector_context.mutable_computation_properties()
233*14675a02SAndroid Build Coastguard Worker ->mutable_federated()
234*14675a02SAndroid Build Coastguard Worker ->set_task_name(task_name);
235*14675a02SAndroid Build Coastguard Worker if (client_plan.phase().has_example_query_spec()) {
236*14675a02SAndroid Build Coastguard Worker federated_selector_context.mutable_computation_properties()
237*14675a02SAndroid Build Coastguard Worker ->set_example_iterator_output_format(
238*14675a02SAndroid Build Coastguard Worker ::fcp::client::QueryTimeComputationProperties::
239*14675a02SAndroid Build Coastguard Worker EXAMPLE_QUERY_RESULT);
240*14675a02SAndroid Build Coastguard Worker } else {
241*14675a02SAndroid Build Coastguard Worker const auto& federated_compute_io_router =
242*14675a02SAndroid Build Coastguard Worker client_plan.phase().federated_compute();
243*14675a02SAndroid Build Coastguard Worker const bool has_simpleagg_tensors =
244*14675a02SAndroid Build Coastguard Worker !federated_compute_io_router.output_filepath_tensor_name().empty();
245*14675a02SAndroid Build Coastguard Worker bool all_aggregations_are_secagg = true;
246*14675a02SAndroid Build Coastguard Worker for (const auto& aggregation : federated_compute_io_router.aggregations()) {
247*14675a02SAndroid Build Coastguard Worker all_aggregations_are_secagg &=
248*14675a02SAndroid Build Coastguard Worker aggregation.second.protocol_config_case() ==
249*14675a02SAndroid Build Coastguard Worker AggregationConfig::kSecureAggregation;
250*14675a02SAndroid Build Coastguard Worker }
251*14675a02SAndroid Build Coastguard Worker if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
252*14675a02SAndroid Build Coastguard Worker federated_selector_context.mutable_computation_properties()
253*14675a02SAndroid Build Coastguard Worker ->mutable_federated()
254*14675a02SAndroid Build Coastguard Worker ->mutable_secure_aggregation()
255*14675a02SAndroid Build Coastguard Worker ->set_minimum_clients_in_server_visible_aggregate(100);
256*14675a02SAndroid Build Coastguard Worker } else {
257*14675a02SAndroid Build Coastguard Worker // Has an output checkpoint, so some tensors must be simply aggregated.
258*14675a02SAndroid Build Coastguard Worker *(federated_selector_context.mutable_computation_properties()
259*14675a02SAndroid Build Coastguard Worker ->mutable_federated()
260*14675a02SAndroid Build Coastguard Worker ->mutable_simple_aggregation()) = SimpleAggregation();
261*14675a02SAndroid Build Coastguard Worker }
262*14675a02SAndroid Build Coastguard Worker }
263*14675a02SAndroid Build Coastguard Worker
264*14675a02SAndroid Build Coastguard Worker auto opstats_logger =
265*14675a02SAndroid Build Coastguard Worker engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
266*14675a02SAndroid Build Coastguard Worker session_name, population_name);
267*14675a02SAndroid Build Coastguard Worker
268*14675a02SAndroid Build Coastguard Worker // Check if the device conditions allow for checking in with the server
269*14675a02SAndroid Build Coastguard Worker // and running a federated computation. If not, bail early with the
270*14675a02SAndroid Build Coastguard Worker // transient error retry window.
271*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [env_deps, &timing_config]() {
272*14675a02SAndroid Build Coastguard Worker return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
273*14675a02SAndroid Build Coastguard Worker };
274*14675a02SAndroid Build Coastguard Worker
275*14675a02SAndroid Build Coastguard Worker // Regular plans can use example iterators from the SimpleTaskEnvironment,
276*14675a02SAndroid Build Coastguard Worker // those reading the OpStats DB, or those serving Federated Select slices.
277*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
278*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(env_deps,
279*14675a02SAndroid Build Coastguard Worker federated_selector_context);
280*14675a02SAndroid Build Coastguard Worker std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
281*14675a02SAndroid Build Coastguard Worker env_example_iterator_factory.get()};
282*14675a02SAndroid Build Coastguard Worker PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
283*14675a02SAndroid Build Coastguard Worker client_plan.phase().has_example_query_spec()
284*14675a02SAndroid Build Coastguard Worker ? RunPlanWithExampleQuerySpec(example_iterator_factories,
285*14675a02SAndroid Build Coastguard Worker opstats_logger.get(), flags,
286*14675a02SAndroid Build Coastguard Worker client_plan, checkpoint_output_filename)
287*14675a02SAndroid Build Coastguard Worker : RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
288*14675a02SAndroid Build Coastguard Worker log_manager, opstats_logger.get(), flags,
289*14675a02SAndroid Build Coastguard Worker client_plan, checkpoint_input_filename,
290*14675a02SAndroid Build Coastguard Worker checkpoint_output_filename,
291*14675a02SAndroid Build Coastguard Worker timing_config);
292*14675a02SAndroid Build Coastguard Worker auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
293*14675a02SAndroid Build Coastguard Worker FLRunnerResult fl_runner_result;
294*14675a02SAndroid Build Coastguard Worker
295*14675a02SAndroid Build Coastguard Worker if (outcome == engine::PlanOutcome::kSuccess) {
296*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
297*14675a02SAndroid Build Coastguard Worker } else {
298*14675a02SAndroid Build Coastguard Worker switch (outcome) {
299*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kInvalidArgument:
300*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_error_status(FLRunnerResult::INVALID_ARGUMENT);
301*14675a02SAndroid Build Coastguard Worker break;
302*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kTensorflowError:
303*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_error_status(FLRunnerResult::TENSORFLOW_ERROR);
304*14675a02SAndroid Build Coastguard Worker break;
305*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kExampleIteratorError:
306*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_error_status(
307*14675a02SAndroid Build Coastguard Worker FLRunnerResult::EXAMPLE_ITERATOR_ERROR);
308*14675a02SAndroid Build Coastguard Worker break;
309*14675a02SAndroid Build Coastguard Worker default:
310*14675a02SAndroid Build Coastguard Worker break;
311*14675a02SAndroid Build Coastguard Worker }
312*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
313*14675a02SAndroid Build Coastguard Worker std::string error_message = std::string{
314*14675a02SAndroid Build Coastguard Worker plan_result_and_checkpoint_file.plan_result.original_status.message()};
315*14675a02SAndroid Build Coastguard Worker fl_runner_result.set_error_message(error_message);
316*14675a02SAndroid Build Coastguard Worker }
317*14675a02SAndroid Build Coastguard Worker
318*14675a02SAndroid Build Coastguard Worker FLRunnerResult::ExampleStats example_stats;
319*14675a02SAndroid Build Coastguard Worker example_stats.set_example_count(
320*14675a02SAndroid Build Coastguard Worker plan_result_and_checkpoint_file.plan_result.example_stats.example_count);
321*14675a02SAndroid Build Coastguard Worker example_stats.set_example_size_bytes(
322*14675a02SAndroid Build Coastguard Worker plan_result_and_checkpoint_file.plan_result.example_stats
323*14675a02SAndroid Build Coastguard Worker .example_size_bytes);
324*14675a02SAndroid Build Coastguard Worker
325*14675a02SAndroid Build Coastguard Worker *fl_runner_result.mutable_example_stats() = example_stats;
326*14675a02SAndroid Build Coastguard Worker
327*14675a02SAndroid Build Coastguard Worker return fl_runner_result;
328*14675a02SAndroid Build Coastguard Worker }
329*14675a02SAndroid Build Coastguard Worker
330*14675a02SAndroid Build Coastguard Worker } // namespace client
331*14675a02SAndroid Build Coastguard Worker } // namespace fcp