xref: /aosp_15_r20/external/federated-compute/fcp/client/fcp_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/client/fcp_runner.h"
18 
19 #include "fcp/client/engine/example_iterator_factory.h"
20 #include "fcp/client/engine/example_query_plan_engine.h"
21 #include "fcp/client/engine/plan_engine_helpers.h"
22 #include "fcp/client/engine/tflite_plan_engine.h"
23 #include "fcp/client/fl_runner.pb.h"
24 #include "fcp/client/opstats/opstats_logger.h"
25 #include "fcp/protos/plan.pb.h"
26 
27 namespace fcp {
28 namespace client {
29 
30 using ::fcp::client::opstats::OpStatsLogger;
31 using ::google::internal::federated::plan::AggregationConfig;
32 using ::google::internal::federated::plan::ClientOnlyPlan;
33 using ::google::internal::federated::plan::FederatedComputeIORouter;
34 using ::google::internal::federated::plan::TensorflowSpec;
35 
36 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
37 namespace {
38 
39 // Creates an ExampleIteratorFactory that routes queries to the
40 // SimpleTaskEnvironment::CreateExampleIterator() method.
41 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)42 CreateSimpleTaskEnvironmentIteratorFactory(
43     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
44   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
45       /*can_handle_func=*/
46       [](const google::internal::federated::plan::ExampleSelector&) {
47         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
48         // be the catch-all factory that is able to handle all queries
49         // that no other ExampleIteratorFactory is able to handle.
50         return true;
51       },
52       /*create_iterator_func=*/
53       [task_env, selector_context](
54           const google::internal::federated::plan::ExampleSelector&
55               example_selector) {
56         return task_env->CreateExampleIterator(example_selector,
57                                                selector_context);
58       },
59       /*should_collect_stats=*/true);
60 }
61 
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)62 std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
63     const FederatedComputeIORouter& io_router,
64     const std::string& checkpoint_input_filename,
65     const std::string& checkpoint_output_filename) {
66   auto inputs = std::make_unique<TfLiteInputs>();
67   if (!io_router.input_filepath_tensor_name().empty()) {
68     (*inputs)[io_router.input_filepath_tensor_name()] =
69         checkpoint_input_filename;
70   }
71 
72   if (!io_router.output_filepath_tensor_name().empty()) {
73     (*inputs)[io_router.output_filepath_tensor_name()] =
74         checkpoint_output_filename;
75   }
76 
77   return inputs;
78 }
79 
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)80 absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
81     const TensorflowSpec& tensorflow_spec,
82     const FederatedComputeIORouter& io_router) {
83   std::vector<std::string> output_names;
84   // The order of output tensor names should match the order in TensorflowSpec.
85   for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
86     std::string tensor_name = output_tensor_spec.name();
87     if (!io_router.aggregations().contains(tensor_name) ||
88         !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
89       return absl::InvalidArgumentError(
90           "Output tensor is missing in AggregationConfig, or has unsupported "
91           "aggregation type.");
92     }
93     output_names.push_back(tensor_name);
94   }
95 
96   return output_names;
97 }
98 
99 struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anonb8072fe00111::PlanResultAndCheckpointFile100   explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
101       : plan_result(std::move(plan_result)) {}
102   engine::PlanResult plan_result;
103   std::string checkpoint_file;
104 
105   PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
106   PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
107       default;
108 
109   // Disallow copy and assign.
110   PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
111   PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
112       delete;
113 };
114 
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)115 PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
116     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
117     OpStatsLogger* opstats_logger, const Flags* flags,
118     const ClientOnlyPlan& client_plan,
119     const std::string& checkpoint_output_filename) {
120   if (!client_plan.phase().has_example_query_spec()) {
121     return PlanResultAndCheckpointFile(engine::PlanResult(
122         engine::PlanOutcome::kInvalidArgument,
123         absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
124   }
125   if (!flags->enable_example_query_plan_engine()) {
126     // Example query plan received while the flag is off.
127     return PlanResultAndCheckpointFile(engine::PlanResult(
128         engine::PlanOutcome::kInvalidArgument,
129         absl::InvalidArgumentError(
130             "Example query plan received while the flag is off")));
131   }
132   if (!client_plan.phase().has_federated_example_query()) {
133     return PlanResultAndCheckpointFile(engine::PlanResult(
134         engine::PlanOutcome::kInvalidArgument,
135         absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
136   }
137   for (const auto& example_query :
138        client_plan.phase().example_query_spec().example_queries()) {
139     for (auto const& [vector_name, spec] :
140          example_query.output_vector_specs()) {
141       const auto& aggregations =
142           client_plan.phase().federated_example_query().aggregations();
143       if ((aggregations.find(vector_name) == aggregations.end()) ||
144           !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
145         return PlanResultAndCheckpointFile(engine::PlanResult(
146             engine::PlanOutcome::kInvalidArgument,
147             absl::InvalidArgumentError("Output vector is missing in "
148                                        "AggregationConfig, or has unsupported "
149                                        "aggregation type.")));
150       }
151     }
152   }
153 
154   engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
155                                              opstats_logger);
156   engine::PlanResult plan_result = plan_engine.RunPlan(
157       client_plan.phase().example_query_spec(), checkpoint_output_filename);
158   PlanResultAndCheckpointFile result(std::move(plan_result));
159   result.checkpoint_file = checkpoint_output_filename;
160   return result;
161 }
162 
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)163 PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
164     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
165     std::function<bool()> should_abort, LogManager* log_manager,
166     OpStatsLogger* opstats_logger, const Flags* flags,
167     const ClientOnlyPlan& client_plan,
168     const std::string& checkpoint_input_filename,
169     const std::string& checkpoint_output_filename,
170     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
171   if (!client_plan.phase().has_tensorflow_spec()) {
172     return PlanResultAndCheckpointFile(engine::PlanResult(
173         engine::PlanOutcome::kInvalidArgument,
174         absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
175   }
176   if (!client_plan.phase().has_federated_compute()) {
177     return PlanResultAndCheckpointFile(engine::PlanResult(
178         engine::PlanOutcome::kInvalidArgument,
179         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
180   }
181 
182   // Get the output tensor names.
183   absl::StatusOr<std::vector<std::string>> output_names;
184   output_names = ConstructOutputsWithDeterministicOrder(
185       client_plan.phase().tensorflow_spec(),
186       client_plan.phase().federated_compute());
187   if (!output_names.ok()) {
188     return PlanResultAndCheckpointFile(engine::PlanResult(
189         engine::PlanOutcome::kInvalidArgument, output_names.status()));
190   }
191 
192   // Run plan and get a set of output tensors back.
193   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
194     std::unique_ptr<TfLiteInputs> tflite_inputs =
195         ConstructTFLiteInputsForTensorflowSpecPlan(
196             client_plan.phase().federated_compute(), checkpoint_input_filename,
197             checkpoint_output_filename);
198     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
199                                          should_abort, log_manager,
200                                          opstats_logger, flags, &timing_config);
201     engine::PlanResult plan_result = plan_engine.RunPlan(
202         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
203         std::move(tflite_inputs), *output_names);
204     PlanResultAndCheckpointFile result(std::move(plan_result));
205     result.checkpoint_file = checkpoint_output_filename;
206 
207     return result;
208   }
209 
210   return PlanResultAndCheckpointFile(
211       engine::PlanResult(engine::PlanOutcome::kTensorflowError,
212                          absl::InternalError("No plan engine enabled")));
213 }
214 }  // namespace
215 
RunFederatedComputation(SimpleTaskEnvironment * env_deps,LogManager * log_manager,const Flags * flags,const google::internal::federated::plan::ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const std::string & session_name,const std::string & population_name,const std::string & task_name,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)216 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
217     SimpleTaskEnvironment* env_deps, LogManager* log_manager,
218     const Flags* flags,
219     const google::internal::federated::plan::ClientOnlyPlan& client_plan,
220     const std::string& checkpoint_input_filename,
221     const std::string& checkpoint_output_filename,
222     const std::string& session_name, const std::string& population_name,
223     const std::string& task_name,
224     const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
225   SelectorContext federated_selector_context;
226   federated_selector_context.mutable_computation_properties()->set_session_name(
227       session_name);
228   FederatedComputation federated_computation;
229   federated_computation.set_population_name(population_name);
230   *federated_selector_context.mutable_computation_properties()
231        ->mutable_federated() = federated_computation;
232   federated_selector_context.mutable_computation_properties()
233       ->mutable_federated()
234       ->set_task_name(task_name);
235   if (client_plan.phase().has_example_query_spec()) {
236     federated_selector_context.mutable_computation_properties()
237         ->set_example_iterator_output_format(
238             ::fcp::client::QueryTimeComputationProperties::
239                 EXAMPLE_QUERY_RESULT);
240   } else {
241     const auto& federated_compute_io_router =
242         client_plan.phase().federated_compute();
243     const bool has_simpleagg_tensors =
244         !federated_compute_io_router.output_filepath_tensor_name().empty();
245     bool all_aggregations_are_secagg = true;
246     for (const auto& aggregation : federated_compute_io_router.aggregations()) {
247       all_aggregations_are_secagg &=
248           aggregation.second.protocol_config_case() ==
249           AggregationConfig::kSecureAggregation;
250     }
251     if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
252       federated_selector_context.mutable_computation_properties()
253           ->mutable_federated()
254           ->mutable_secure_aggregation()
255           ->set_minimum_clients_in_server_visible_aggregate(100);
256     } else {
257       // Has an output checkpoint, so some tensors must be simply aggregated.
258       *(federated_selector_context.mutable_computation_properties()
259             ->mutable_federated()
260             ->mutable_simple_aggregation()) = SimpleAggregation();
261     }
262   }
263 
264   auto opstats_logger =
265       engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
266                                   session_name, population_name);
267 
268   // Check if the device conditions allow for checking in with the server
269   // and running a federated computation. If not, bail early with the
270   // transient error retry window.
271   std::function<bool()> should_abort = [env_deps, &timing_config]() {
272     return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
273   };
274 
275   // Regular plans can use example iterators from the SimpleTaskEnvironment,
276   // those reading the OpStats DB, or those serving Federated Select slices.
277   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
278       CreateSimpleTaskEnvironmentIteratorFactory(env_deps,
279                                                  federated_selector_context);
280   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
281       env_example_iterator_factory.get()};
282   PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
283       client_plan.phase().has_example_query_spec()
284           ? RunPlanWithExampleQuerySpec(example_iterator_factories,
285                                         opstats_logger.get(), flags,
286                                         client_plan, checkpoint_output_filename)
287           : RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
288                                       log_manager, opstats_logger.get(), flags,
289                                       client_plan, checkpoint_input_filename,
290                                       checkpoint_output_filename,
291                                       timing_config);
292   auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
293   FLRunnerResult fl_runner_result;
294 
295   if (outcome == engine::PlanOutcome::kSuccess) {
296     fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
297   } else {
298     switch (outcome) {
299       case engine::PlanOutcome::kInvalidArgument:
300         fl_runner_result.set_error_status(FLRunnerResult::INVALID_ARGUMENT);
301         break;
302       case engine::PlanOutcome::kTensorflowError:
303         fl_runner_result.set_error_status(FLRunnerResult::TENSORFLOW_ERROR);
304         break;
305       case engine::PlanOutcome::kExampleIteratorError:
306         fl_runner_result.set_error_status(
307             FLRunnerResult::EXAMPLE_ITERATOR_ERROR);
308         break;
309       default:
310         break;
311     }
312     fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
313     std::string error_message = std::string{
314         plan_result_and_checkpoint_file.plan_result.original_status.message()};
315     fl_runner_result.set_error_message(error_message);
316   }
317 
318   FLRunnerResult::ExampleStats example_stats;
319   example_stats.set_example_count(
320       plan_result_and_checkpoint_file.plan_result.example_stats.example_count);
321   example_stats.set_example_size_bytes(
322       plan_result_and_checkpoint_file.plan_result.example_stats
323           .example_size_bytes);
324 
325   *fl_runner_result.mutable_example_stats() = example_stats;
326 
327   return fl_runner_result;
328 }
329 
330 }  // namespace client
331 }  // namespace fcp