1 /*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/client/fcp_runner.h"
18
19 #include "fcp/client/engine/example_iterator_factory.h"
20 #include "fcp/client/engine/example_query_plan_engine.h"
21 #include "fcp/client/engine/plan_engine_helpers.h"
22 #include "fcp/client/engine/tflite_plan_engine.h"
23 #include "fcp/client/fl_runner.pb.h"
24 #include "fcp/client/opstats/opstats_logger.h"
25 #include "fcp/protos/plan.pb.h"
26
27 namespace fcp {
28 namespace client {
29
30 using ::fcp::client::opstats::OpStatsLogger;
31 using ::google::internal::federated::plan::AggregationConfig;
32 using ::google::internal::federated::plan::ClientOnlyPlan;
33 using ::google::internal::federated::plan::FederatedComputeIORouter;
34 using ::google::internal::federated::plan::TensorflowSpec;
35
36 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
37 namespace {
38
39 // Creates an ExampleIteratorFactory that routes queries to the
40 // SimpleTaskEnvironment::CreateExampleIterator() method.
41 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)42 CreateSimpleTaskEnvironmentIteratorFactory(
43 SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
44 return std::make_unique<engine::FunctionalExampleIteratorFactory>(
45 /*can_handle_func=*/
46 [](const google::internal::federated::plan::ExampleSelector&) {
47 // The SimpleTaskEnvironment-based ExampleIteratorFactory should
48 // be the catch-all factory that is able to handle all queries
49 // that no other ExampleIteratorFactory is able to handle.
50 return true;
51 },
52 /*create_iterator_func=*/
53 [task_env, selector_context](
54 const google::internal::federated::plan::ExampleSelector&
55 example_selector) {
56 return task_env->CreateExampleIterator(example_selector,
57 selector_context);
58 },
59 /*should_collect_stats=*/true);
60 }
61
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)62 std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
63 const FederatedComputeIORouter& io_router,
64 const std::string& checkpoint_input_filename,
65 const std::string& checkpoint_output_filename) {
66 auto inputs = std::make_unique<TfLiteInputs>();
67 if (!io_router.input_filepath_tensor_name().empty()) {
68 (*inputs)[io_router.input_filepath_tensor_name()] =
69 checkpoint_input_filename;
70 }
71
72 if (!io_router.output_filepath_tensor_name().empty()) {
73 (*inputs)[io_router.output_filepath_tensor_name()] =
74 checkpoint_output_filename;
75 }
76
77 return inputs;
78 }
79
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)80 absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
81 const TensorflowSpec& tensorflow_spec,
82 const FederatedComputeIORouter& io_router) {
83 std::vector<std::string> output_names;
84 // The order of output tensor names should match the order in TensorflowSpec.
85 for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
86 std::string tensor_name = output_tensor_spec.name();
87 if (!io_router.aggregations().contains(tensor_name) ||
88 !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
89 return absl::InvalidArgumentError(
90 "Output tensor is missing in AggregationConfig, or has unsupported "
91 "aggregation type.");
92 }
93 output_names.push_back(tensor_name);
94 }
95
96 return output_names;
97 }
98
99 struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anonb8072fe00111::PlanResultAndCheckpointFile100 explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
101 : plan_result(std::move(plan_result)) {}
102 engine::PlanResult plan_result;
103 std::string checkpoint_file;
104
105 PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
106 PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
107 default;
108
109 // Disallow copy and assign.
110 PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
111 PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
112 delete;
113 };
114
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)115 PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
116 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
117 OpStatsLogger* opstats_logger, const Flags* flags,
118 const ClientOnlyPlan& client_plan,
119 const std::string& checkpoint_output_filename) {
120 if (!client_plan.phase().has_example_query_spec()) {
121 return PlanResultAndCheckpointFile(engine::PlanResult(
122 engine::PlanOutcome::kInvalidArgument,
123 absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
124 }
125 if (!flags->enable_example_query_plan_engine()) {
126 // Example query plan received while the flag is off.
127 return PlanResultAndCheckpointFile(engine::PlanResult(
128 engine::PlanOutcome::kInvalidArgument,
129 absl::InvalidArgumentError(
130 "Example query plan received while the flag is off")));
131 }
132 if (!client_plan.phase().has_federated_example_query()) {
133 return PlanResultAndCheckpointFile(engine::PlanResult(
134 engine::PlanOutcome::kInvalidArgument,
135 absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
136 }
137 for (const auto& example_query :
138 client_plan.phase().example_query_spec().example_queries()) {
139 for (auto const& [vector_name, spec] :
140 example_query.output_vector_specs()) {
141 const auto& aggregations =
142 client_plan.phase().federated_example_query().aggregations();
143 if ((aggregations.find(vector_name) == aggregations.end()) ||
144 !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
145 return PlanResultAndCheckpointFile(engine::PlanResult(
146 engine::PlanOutcome::kInvalidArgument,
147 absl::InvalidArgumentError("Output vector is missing in "
148 "AggregationConfig, or has unsupported "
149 "aggregation type.")));
150 }
151 }
152 }
153
154 engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
155 opstats_logger);
156 engine::PlanResult plan_result = plan_engine.RunPlan(
157 client_plan.phase().example_query_spec(), checkpoint_output_filename);
158 PlanResultAndCheckpointFile result(std::move(plan_result));
159 result.checkpoint_file = checkpoint_output_filename;
160 return result;
161 }
162
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)163 PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
164 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
165 std::function<bool()> should_abort, LogManager* log_manager,
166 OpStatsLogger* opstats_logger, const Flags* flags,
167 const ClientOnlyPlan& client_plan,
168 const std::string& checkpoint_input_filename,
169 const std::string& checkpoint_output_filename,
170 const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
171 if (!client_plan.phase().has_tensorflow_spec()) {
172 return PlanResultAndCheckpointFile(engine::PlanResult(
173 engine::PlanOutcome::kInvalidArgument,
174 absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
175 }
176 if (!client_plan.phase().has_federated_compute()) {
177 return PlanResultAndCheckpointFile(engine::PlanResult(
178 engine::PlanOutcome::kInvalidArgument,
179 absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
180 }
181
182 // Get the output tensor names.
183 absl::StatusOr<std::vector<std::string>> output_names;
184 output_names = ConstructOutputsWithDeterministicOrder(
185 client_plan.phase().tensorflow_spec(),
186 client_plan.phase().federated_compute());
187 if (!output_names.ok()) {
188 return PlanResultAndCheckpointFile(engine::PlanResult(
189 engine::PlanOutcome::kInvalidArgument, output_names.status()));
190 }
191
192 // Run plan and get a set of output tensors back.
193 if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
194 std::unique_ptr<TfLiteInputs> tflite_inputs =
195 ConstructTFLiteInputsForTensorflowSpecPlan(
196 client_plan.phase().federated_compute(), checkpoint_input_filename,
197 checkpoint_output_filename);
198 engine::TfLitePlanEngine plan_engine(example_iterator_factories,
199 should_abort, log_manager,
200 opstats_logger, flags, &timing_config);
201 engine::PlanResult plan_result = plan_engine.RunPlan(
202 client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
203 std::move(tflite_inputs), *output_names);
204 PlanResultAndCheckpointFile result(std::move(plan_result));
205 result.checkpoint_file = checkpoint_output_filename;
206
207 return result;
208 }
209
210 return PlanResultAndCheckpointFile(
211 engine::PlanResult(engine::PlanOutcome::kTensorflowError,
212 absl::InternalError("No plan engine enabled")));
213 }
214 } // namespace
215
RunFederatedComputation(SimpleTaskEnvironment * env_deps,LogManager * log_manager,const Flags * flags,const google::internal::federated::plan::ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const std::string & session_name,const std::string & population_name,const std::string & task_name,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)216 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
217 SimpleTaskEnvironment* env_deps, LogManager* log_manager,
218 const Flags* flags,
219 const google::internal::federated::plan::ClientOnlyPlan& client_plan,
220 const std::string& checkpoint_input_filename,
221 const std::string& checkpoint_output_filename,
222 const std::string& session_name, const std::string& population_name,
223 const std::string& task_name,
224 const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
225 SelectorContext federated_selector_context;
226 federated_selector_context.mutable_computation_properties()->set_session_name(
227 session_name);
228 FederatedComputation federated_computation;
229 federated_computation.set_population_name(population_name);
230 *federated_selector_context.mutable_computation_properties()
231 ->mutable_federated() = federated_computation;
232 federated_selector_context.mutable_computation_properties()
233 ->mutable_federated()
234 ->set_task_name(task_name);
235 if (client_plan.phase().has_example_query_spec()) {
236 federated_selector_context.mutable_computation_properties()
237 ->set_example_iterator_output_format(
238 ::fcp::client::QueryTimeComputationProperties::
239 EXAMPLE_QUERY_RESULT);
240 } else {
241 const auto& federated_compute_io_router =
242 client_plan.phase().federated_compute();
243 const bool has_simpleagg_tensors =
244 !federated_compute_io_router.output_filepath_tensor_name().empty();
245 bool all_aggregations_are_secagg = true;
246 for (const auto& aggregation : federated_compute_io_router.aggregations()) {
247 all_aggregations_are_secagg &=
248 aggregation.second.protocol_config_case() ==
249 AggregationConfig::kSecureAggregation;
250 }
251 if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
252 federated_selector_context.mutable_computation_properties()
253 ->mutable_federated()
254 ->mutable_secure_aggregation()
255 ->set_minimum_clients_in_server_visible_aggregate(100);
256 } else {
257 // Has an output checkpoint, so some tensors must be simply aggregated.
258 *(federated_selector_context.mutable_computation_properties()
259 ->mutable_federated()
260 ->mutable_simple_aggregation()) = SimpleAggregation();
261 }
262 }
263
264 auto opstats_logger =
265 engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
266 session_name, population_name);
267
268 // Check if the device conditions allow for checking in with the server
269 // and running a federated computation. If not, bail early with the
270 // transient error retry window.
271 std::function<bool()> should_abort = [env_deps, &timing_config]() {
272 return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
273 };
274
275 // Regular plans can use example iterators from the SimpleTaskEnvironment,
276 // those reading the OpStats DB, or those serving Federated Select slices.
277 std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
278 CreateSimpleTaskEnvironmentIteratorFactory(env_deps,
279 federated_selector_context);
280 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
281 env_example_iterator_factory.get()};
282 PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
283 client_plan.phase().has_example_query_spec()
284 ? RunPlanWithExampleQuerySpec(example_iterator_factories,
285 opstats_logger.get(), flags,
286 client_plan, checkpoint_output_filename)
287 : RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
288 log_manager, opstats_logger.get(), flags,
289 client_plan, checkpoint_input_filename,
290 checkpoint_output_filename,
291 timing_config);
292 auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
293 FLRunnerResult fl_runner_result;
294
295 if (outcome == engine::PlanOutcome::kSuccess) {
296 fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
297 } else {
298 switch (outcome) {
299 case engine::PlanOutcome::kInvalidArgument:
300 fl_runner_result.set_error_status(FLRunnerResult::INVALID_ARGUMENT);
301 break;
302 case engine::PlanOutcome::kTensorflowError:
303 fl_runner_result.set_error_status(FLRunnerResult::TENSORFLOW_ERROR);
304 break;
305 case engine::PlanOutcome::kExampleIteratorError:
306 fl_runner_result.set_error_status(
307 FLRunnerResult::EXAMPLE_ITERATOR_ERROR);
308 break;
309 default:
310 break;
311 }
312 fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
313 std::string error_message = std::string{
314 plan_result_and_checkpoint_file.plan_result.original_status.message()};
315 fl_runner_result.set_error_message(error_message);
316 }
317
318 FLRunnerResult::ExampleStats example_stats;
319 example_stats.set_example_count(
320 plan_result_and_checkpoint_file.plan_result.example_stats.example_count);
321 example_stats.set_example_size_bytes(
322 plan_result_and_checkpoint_file.plan_result.example_stats
323 .example_size_bytes);
324
325 *fl_runner_result.mutable_example_stats() = example_stats;
326
327 return fl_runner_result;
328 }
329
330 } // namespace client
331 } // namespace fcp