1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/fl_runner.h"
17
18 #include <fcntl.h>
19
20 #include <fstream>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 #include <optional>
25 #include <string>
26 #include <utility>
27 #include <variant>
28 #include <vector>
29
30 #include "absl/status/status.h"
31 #include "absl/status/statusor.h"
32 #include "absl/strings/cord.h"
33 #include "absl/time/time.h"
34 #include "fcp/base/clock.h"
35 #include "fcp/base/monitoring.h"
36 #include "fcp/base/platform.h"
37 // #include "fcp/client/cache/file_backed_resource_cache.h"
38 #include "fcp/client/cache/resource_cache.h"
39 #include "fcp/client/engine/common.h"
40 #include "fcp/client/engine/engine.pb.h"
41 #include "fcp/client/engine/example_iterator_factory.h"
42 #include "fcp/client/engine/example_query_plan_engine.h"
43 #include "fcp/client/engine/plan_engine_helpers.h"
44 #include "fcp/client/opstats/opstats_utils.h"
45 #include "fcp/client/parsing_utils.h"
46 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
47 #include "fcp/client/engine/simple_plan_engine.h"
48 #endif
49 #include "fcp/client/engine/tflite_plan_engine.h"
50 #include "fcp/client/event_publisher.h"
51 #include "fcp/client/federated_protocol.h"
52 #include "fcp/client/federated_protocol_util.h"
53 #include "fcp/client/files.h"
54 #include "fcp/client/fl_runner.pb.h"
55 #include "fcp/client/flags.h"
56 #include "fcp/client/http/http_federated_protocol.h"
57 #ifdef FCP_CLIENT_SUPPORT_GRPC
58 #include "fcp/client/grpc_federated_protocol.h"
59 #endif
60 #include "fcp/client/interruptible_runner.h"
61 #include "fcp/client/log_manager.h"
62 #include "fcp/client/opstats/opstats_example_store.h"
63 #include "fcp/client/phase_logger_impl.h"
64 #include "fcp/client/secagg_runner.h"
65 #include "fcp/client/selector_context.pb.h"
66 #include "fcp/client/simple_task_environment.h"
67 #include "fcp/protos/federated_api.pb.h"
68 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
69 #include "fcp/protos/opstats.pb.h"
70 #include "fcp/protos/plan.pb.h"
71 #include "openssl/digest.h"
72 #include "openssl/evp.h"
73 #include "tensorflow/core/framework/tensor.h"
74 #include "tensorflow/core/framework/tensor.pb.h"
75 #include "tensorflow/core/framework/tensor_shape.pb.h"
76 #include "tensorflow/core/protobuf/struct.pb.h"
77 namespace fcp {
78 namespace client {
79 using ::fcp::client::opstats::OpStatsLogger;
80 using ::google::internal::federated::plan::AggregationConfig;
81 using ::google::internal::federated::plan::ClientOnlyPlan;
82 using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
83 using ::google::internal::federated::plan::FederatedComputeIORouter;
84 using ::google::internal::federated::plan::TensorflowSpec;
85 using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
86 using ::google::internal::federatedml::v2::RetryWindow;
87 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
88 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
89 namespace {
90 template <typename T>
AddValuesToQuantized(QuantizedTensor * quantized,const tensorflow::Tensor & tensor)91 void AddValuesToQuantized(QuantizedTensor* quantized,
92 const tensorflow::Tensor& tensor) {
93 auto flat_tensor = tensor.flat<T>();
94 quantized->values.reserve(quantized->values.size() + flat_tensor.size());
95 for (int i = 0; i < flat_tensor.size(); i++) {
96 quantized->values.push_back(flat_tensor(i));
97 }
98 }
ComputeSHA256FromStringOrCord(std::variant<std::string,absl::Cord> data)99 std::string ComputeSHA256FromStringOrCord(
100 std::variant<std::string, absl::Cord> data) {
101 std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
102 EVP_MD_CTX_destroy);
103 FCP_CHECK(EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr));
104 std::string plan_str;
105 if (std::holds_alternative<std::string>(data)) {
106 plan_str = std::get<std::string>(data);
107 } else {
108 plan_str = std::string(std::get<absl::Cord>(data));
109 }
110 FCP_CHECK(EVP_DigestUpdate(mdctx.get(), plan_str.c_str(), sizeof(int)));
111 const int hash_len = 32; // 32 bytes for SHA-256.
112 uint8_t computation_id_bytes[hash_len];
113 FCP_CHECK(EVP_DigestFinal_ex(mdctx.get(), computation_id_bytes, nullptr));
114 return std::string(reinterpret_cast<char const*>(computation_id_bytes),
115 hash_len);
116 }
117 struct PlanResultAndCheckpointFile {
PlanResultAndCheckpointFilefcp::client::__anon340cd2590111::PlanResultAndCheckpointFile118 explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
119 : plan_result(std::move(plan_result)) {}
120 engine::PlanResult plan_result;
121 std::string checkpoint_file;
122 PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
123 PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
124 default;
125 // Disallow copy and assign.
126 PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
127 PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
128 delete;
129 };
130 // Creates computation results. The method checks for SecAgg tensors only if
131 // `tensorflow_spec != nullptr`.
CreateComputationResults(const TensorflowSpec * tensorflow_spec,const PlanResultAndCheckpointFile & plan_result_and_checkpoint_file)132 absl::StatusOr<ComputationResults> CreateComputationResults(
133 const TensorflowSpec* tensorflow_spec,
134 const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file) {
135 const auto& [plan_result, checkpoint_file] = plan_result_and_checkpoint_file;
136 if (plan_result.outcome != engine::PlanOutcome::kSuccess) {
137 return absl::InvalidArgumentError("Computation failed.");
138 }
139 ComputationResults computation_results;
140 if (tensorflow_spec != nullptr) {
141 for (int i = 0; i < plan_result.output_names.size(); i++) {
142 QuantizedTensor quantized;
143 const auto& output_tensor = plan_result.output_tensors[i];
144 switch (output_tensor.dtype()) {
145 case tensorflow::DT_INT8:
146 AddValuesToQuantized<int8_t>(&quantized, output_tensor);
147 quantized.bitwidth = 7;
148 break;
149 case tensorflow::DT_UINT8:
150 AddValuesToQuantized<uint8_t>(&quantized, output_tensor);
151 quantized.bitwidth = 8;
152 break;
153 case tensorflow::DT_INT16:
154 AddValuesToQuantized<int16_t>(&quantized, output_tensor);
155 quantized.bitwidth = 15;
156 break;
157 case tensorflow::DT_UINT16:
158 AddValuesToQuantized<uint16_t>(&quantized, output_tensor);
159 quantized.bitwidth = 16;
160 break;
161 case tensorflow::DT_INT32:
162 AddValuesToQuantized<int32_t>(&quantized, output_tensor);
163 quantized.bitwidth = 31;
164 break;
165 case tensorflow::DT_INT64:
166 AddValuesToQuantized<tensorflow::int64>(&quantized, output_tensor);
167 quantized.bitwidth = 62;
168 break;
169 default:
170 return absl::InvalidArgumentError(
171 absl::StrCat("Tensor of type",
172 tensorflow::DataType_Name(output_tensor.dtype()),
173 "could not be converted to quantized value"));
174 }
175 computation_results[plan_result.output_names[i]] = std::move(quantized);
176 }
177 // Add dimensions to QuantizedTensors.
178 for (const tensorflow::TensorSpecProto& tensor_spec :
179 tensorflow_spec->output_tensor_specs()) {
180 if (computation_results.find(tensor_spec.name()) !=
181 computation_results.end()) {
182 for (const tensorflow::TensorShapeProto_Dim& dim :
183 tensor_spec.shape().dim()) {
184 std::get<QuantizedTensor>(computation_results[tensor_spec.name()])
185 .dimensions.push_back(dim.size());
186 }
187 }
188 }
189 }
190 // Name of the TF checkpoint inside the aggregand map in the Checkpoint
191 // protobuf. This field name is ignored by the server.
192 if (!checkpoint_file.empty()) {
193 FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
194 fcp::ReadFileToString(checkpoint_file));
195 computation_results[std::string(kTensorflowCheckpointAggregand)] =
196 std::move(tf_checkpoint);
197 }
198 return computation_results;
199 }
#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
// Builds the TF Mobile input feeds for an eligibility eval plan. If the IO
// router names an input-filepath tensor, a scalar DT_STRING tensor holding
// `checkpoint_input_filename` is fed under that name. Otherwise the returned
// list is empty.
std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForEligibilityEvalPlan(
    const FederatedComputeEligibilityIORouter& io_router,
    const std::string& checkpoint_input_filename) {
  using NamedTensors = std::vector<std::pair<std::string, tensorflow::Tensor>>;
  auto feeds = std::make_unique<NamedTensors>();
  const std::string& filepath_tensor_name =
      io_router.input_filepath_tensor_name();
  if (filepath_tensor_name.empty()) {
    return feeds;
  }
  tensorflow::Tensor filepath_tensor(tensorflow::DT_STRING, {});
  filepath_tensor.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
  feeds->emplace_back(filepath_tensor_name, std::move(filepath_tensor));
  return feeds;
}
#endif
ConstructTfLiteInputsForEligibilityEvalPlan(const FederatedComputeEligibilityIORouter & io_router,const std::string & checkpoint_input_filename)215 std::unique_ptr<TfLiteInputs> ConstructTfLiteInputsForEligibilityEvalPlan(
216 const FederatedComputeEligibilityIORouter& io_router,
217 const std::string& checkpoint_input_filename) {
218 auto inputs = std::make_unique<TfLiteInputs>();
219 if (!io_router.input_filepath_tensor_name().empty()) {
220 (*inputs)[io_router.input_filepath_tensor_name()] =
221 checkpoint_input_filename;
222 }
223 return inputs;
224 }
225 // Returns the cumulative network stats (those incurred up until this point in
226 // time).
227 //
228 // The `FederatedSelectManager` object may be null, if it is know that there
229 // has been no network usage from it yet.
GetCumulativeNetworkStats(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager)230 NetworkStats GetCumulativeNetworkStats(
231 FederatedProtocol* federated_protocol,
232 FederatedSelectManager* fedselect_manager) {
233 NetworkStats result = federated_protocol->GetNetworkStats();
234 if (fedselect_manager != nullptr) {
235 result = result + fedselect_manager->GetNetworkStats();
236 }
237 return result;
238 }
239 // Returns the newly incurred network stats since the previous snapshot of
240 // stats (the `reference_point` argument).
GetNetworkStatsSince(FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const NetworkStats & reference_point)241 NetworkStats GetNetworkStatsSince(FederatedProtocol* federated_protocol,
242 FederatedSelectManager* fedselect_manager,
243 const NetworkStats& reference_point) {
244 return GetCumulativeNetworkStats(federated_protocol, fedselect_manager) -
245 reference_point;
246 }
247 // Updates the fields of `FLRunnerResult` that should always be updated after
248 // each interaction with the `FederatedProtocol` or `FederatedSelectManager`
249 // objects.
250 //
251 // The `FederatedSelectManager` object may be null, if it is know that there
252 // has been no network usage from it yet.
UpdateRetryWindowAndNetworkStats(FederatedProtocol & federated_protocol,FederatedSelectManager * fedselect_manager,PhaseLogger & phase_logger,FLRunnerResult & fl_runner_result)253 void UpdateRetryWindowAndNetworkStats(FederatedProtocol& federated_protocol,
254 FederatedSelectManager* fedselect_manager,
255 PhaseLogger& phase_logger,
256 FLRunnerResult& fl_runner_result) {
257 // Update the result's retry window to the most recent one.
258 auto retry_window = federated_protocol.GetLatestRetryWindow();
259 RetryInfo retry_info;
260 *retry_info.mutable_retry_token() = retry_window.retry_token();
261 *retry_info.mutable_minimum_delay() = retry_window.delay_min();
262 *fl_runner_result.mutable_retry_info() = retry_info;
263 phase_logger.UpdateRetryWindowAndNetworkStats(
264 retry_window,
265 GetCumulativeNetworkStats(&federated_protocol, fedselect_manager));
266 }
267 // Creates an ExampleIteratorFactory that routes queries to the
268 // SimpleTaskEnvironment::CreateExampleIterator() method.
269 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)270 CreateSimpleTaskEnvironmentIteratorFactory(
271 SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
272 return std::make_unique<engine::FunctionalExampleIteratorFactory>(
273 /*can_handle_func=*/
274 [](const google::internal::federated::plan::ExampleSelector&) {
275 // The SimpleTaskEnvironment-based ExampleIteratorFactory should
276 // be the catch-all factory that is able to handle all queries
277 // that no other ExampleIteratorFactory is able to handle.
278 return true;
279 },
280 /*create_iterator_func=*/
281 [task_env, selector_context](
282 const google::internal::federated::plan::ExampleSelector&
283 example_selector) {
284 return task_env->CreateExampleIterator(example_selector,
285 selector_context);
286 },
287 /*should_collect_stats=*/true);
288 }
RunEligibilityEvalPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)289 engine::PlanResult RunEligibilityEvalPlanWithTensorflowSpec(
290 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
291 std::function<bool()> should_abort, LogManager* log_manager,
292 OpStatsLogger* opstats_logger, const Flags* flags,
293 const ClientOnlyPlan& client_plan,
294 const std::string& checkpoint_input_filename,
295 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
296 const absl::Time run_plan_start_time, const absl::Time reference_time) {
297 // Check that this is a TensorflowSpec-based plan for federated eligibility
298 // computation.
299 if (!client_plan.phase().has_tensorflow_spec() ||
300 !client_plan.phase().has_federated_compute_eligibility()) {
301 return engine::PlanResult(
302 engine::PlanOutcome::kInvalidArgument,
303 absl::InvalidArgumentError("Invalid eligibility eval plan"));
304 }
305 const FederatedComputeEligibilityIORouter& io_router =
306 client_plan.phase().federated_compute_eligibility();
307 std::vector<std::string> output_names = {
308 io_router.task_eligibility_info_tensor_name()};
309 if (!client_plan.tflite_graph().empty()) {
310 log_manager->LogDiag(
311 ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
312 }
313 if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
314 std::unique_ptr<TfLiteInputs> tflite_inputs =
315 ConstructTfLiteInputsForEligibilityEvalPlan(io_router,
316 checkpoint_input_filename);
317 engine::TfLitePlanEngine plan_engine(example_iterator_factories,
318 should_abort, log_manager,
319 opstats_logger, flags, &timing_config);
320 return plan_engine.RunPlan(client_plan.phase().tensorflow_spec(),
321 client_plan.tflite_graph(),
322 std::move(tflite_inputs), output_names);
323 }
324 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
325 // Construct input tensors and output tensor names based on the values in
326 // the FederatedComputeEligibilityIORouter message.
327 auto inputs = ConstructInputsForEligibilityEvalPlan(
328 io_router, checkpoint_input_filename);
329 // Run plan and get a set of output tensors back.
330 engine::SimplePlanEngine plan_engine(
331 example_iterator_factories, should_abort, log_manager, opstats_logger,
332 &timing_config, flags->support_constant_tf_inputs());
333 return plan_engine.RunPlan(
334 client_plan.phase().tensorflow_spec(), client_plan.graph(),
335 client_plan.tensorflow_config_proto(), std::move(inputs), output_names);
336 #else
337 return engine::PlanResult(
338 engine::PlanOutcome::kTensorflowError,
339 absl::InternalError("No eligibility eval plan engine enabled"));
340 #endif
341 }
342 // Validates the output tensors that resulted from executing the plan, and
343 // then parses the output into a TaskEligibilityInfo proto. Returns an error
344 // if validation or parsing failed.
ParseEligibilityEvalPlanOutput(const std::vector<tensorflow::Tensor> & output_tensors)345 absl::StatusOr<TaskEligibilityInfo> ParseEligibilityEvalPlanOutput(
346 const std::vector<tensorflow::Tensor>& output_tensors) {
347 auto output_size = output_tensors.size();
348 if (output_size != 1) {
349 return absl::InvalidArgumentError(
350 absl::StrCat("Unexpected number of output tensors: ", output_size));
351 }
352 auto output_elements = output_tensors[0].NumElements();
353 if (output_elements != 1) {
354 return absl::InvalidArgumentError(absl::StrCat(
355 "Unexpected number of output tensor elements: ", output_elements));
356 }
357 tensorflow::DataType output_type = output_tensors[0].dtype();
358 if (output_type != tensorflow::DT_STRING) {
359 return absl::InvalidArgumentError(
360 absl::StrCat("Unexpected output tensor type: ", output_type));
361 }
362 // Extract the serialized TaskEligibilityInfo proto from the tensor and
363 // parse it.
364 // First, convert the output Tensor into a Scalar (= a TensorMap with 1
365 // element), then use its operator() to access the actual data.
366 const tensorflow::tstring& serialized_output =
367 output_tensors[0].scalar<const tensorflow::tstring>()();
368 TaskEligibilityInfo parsed_output;
369 if (!parsed_output.ParseFromString(serialized_output)) {
370 return absl::InvalidArgumentError("Could not parse output proto");
371 }
372 return parsed_output;
373 }
#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
// Builds the TF Mobile input feeds for a TensorflowSpec-based plan.
//
// Each filepath tensor name that the IO router sets gets a scalar DT_STRING
// tensor holding the matching checkpoint filename: the input name gets the
// input filename, then the output name gets the output filename. Empty names
// contribute no feed.
std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
ConstructInputsForTensorflowSpecPlan(
    const FederatedComputeIORouter& io_router,
    const std::string& checkpoint_input_filename,
    const std::string& checkpoint_output_filename) {
  auto feeds = std::make_unique<
      std::vector<std::pair<std::string, tensorflow::Tensor>>>();
  // Appends a scalar string feed `tensor_name` -> `value`, unless
  // `tensor_name` is empty.
  const auto add_filepath_feed = [&feeds](const std::string& tensor_name,
                                          const std::string& value) {
    if (tensor_name.empty()) {
      return;
    }
    tensorflow::Tensor filepath(tensorflow::DT_STRING, {});
    filepath.scalar<tensorflow::tstring>()() = value;
    feeds->emplace_back(tensor_name, std::move(filepath));
  };
  add_filepath_feed(io_router.input_filepath_tensor_name(),
                    checkpoint_input_filename);
  add_filepath_feed(io_router.output_filepath_tensor_name(),
                    checkpoint_output_filename);
  return feeds;
}
#endif
ConstructTFLiteInputsForTensorflowSpecPlan(const FederatedComputeIORouter & io_router,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename)397 std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
398 const FederatedComputeIORouter& io_router,
399 const std::string& checkpoint_input_filename,
400 const std::string& checkpoint_output_filename) {
401 auto inputs = std::make_unique<TfLiteInputs>();
402 if (!io_router.input_filepath_tensor_name().empty()) {
403 (*inputs)[io_router.input_filepath_tensor_name()] =
404 checkpoint_input_filename;
405 }
406 if (!io_router.output_filepath_tensor_name().empty()) {
407 (*inputs)[io_router.output_filepath_tensor_name()] =
408 checkpoint_output_filename;
409 }
410 return inputs;
411 }
ConstructOutputsWithDeterministicOrder(const TensorflowSpec & tensorflow_spec,const FederatedComputeIORouter & io_router)412 absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
413 const TensorflowSpec& tensorflow_spec,
414 const FederatedComputeIORouter& io_router) {
415 std::vector<std::string> output_names;
416 // The order of output tensor names should match the order in
417 // TensorflowSpec.
418 for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
419 std::string tensor_name = output_tensor_spec.name();
420 if (!io_router.aggregations().contains(tensor_name) ||
421 !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
422 return absl::InvalidArgumentError(
423 "Output tensor is missing in AggregationConfig, or has unsupported "
424 "aggregation type.");
425 }
426 output_names.push_back(tensor_name);
427 }
428 return output_names;
429 }
RunPlanWithTensorflowSpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const std::string & checkpoint_output_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config)430 PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
431 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
432 std::function<bool()> should_abort, LogManager* log_manager,
433 OpStatsLogger* opstats_logger, const Flags* flags,
434 const ClientOnlyPlan& client_plan,
435 const std::string& checkpoint_input_filename,
436 const std::string& checkpoint_output_filename,
437 const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
438 if (!client_plan.phase().has_tensorflow_spec()) {
439 return PlanResultAndCheckpointFile(engine::PlanResult(
440 engine::PlanOutcome::kInvalidArgument,
441 absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
442 }
443 if (!client_plan.phase().has_federated_compute()) {
444 return PlanResultAndCheckpointFile(engine::PlanResult(
445 engine::PlanOutcome::kInvalidArgument,
446 absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
447 }
448 // Get the output tensor names.
449 absl::StatusOr<std::vector<std::string>> output_names;
450 output_names = ConstructOutputsWithDeterministicOrder(
451 client_plan.phase().tensorflow_spec(),
452 client_plan.phase().federated_compute());
453 if (!output_names.ok()) {
454 return PlanResultAndCheckpointFile(engine::PlanResult(
455 engine::PlanOutcome::kInvalidArgument, output_names.status()));
456 }
457 // Run plan and get a set of output tensors back.
458 if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
459 std::unique_ptr<TfLiteInputs> tflite_inputs =
460 ConstructTFLiteInputsForTensorflowSpecPlan(
461 client_plan.phase().federated_compute(), checkpoint_input_filename,
462 checkpoint_output_filename);
463 engine::TfLitePlanEngine plan_engine(example_iterator_factories,
464 should_abort, log_manager,
465 opstats_logger, flags, &timing_config);
466 engine::PlanResult plan_result = plan_engine.RunPlan(
467 client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
468 std::move(tflite_inputs), *output_names);
469 PlanResultAndCheckpointFile result(std::move(plan_result));
470 result.checkpoint_file = checkpoint_output_filename;
471 return result;
472 }
473 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
474 // Construct input tensors based on the values in the
475 // FederatedComputeIORouter message and create a temporary file for the
476 // output checkpoint if needed.
477 auto inputs = ConstructInputsForTensorflowSpecPlan(
478 client_plan.phase().federated_compute(), checkpoint_input_filename,
479 checkpoint_output_filename);
480 engine::SimplePlanEngine plan_engine(
481 example_iterator_factories, should_abort, log_manager, opstats_logger,
482 &timing_config, flags->support_constant_tf_inputs());
483 engine::PlanResult plan_result = plan_engine.RunPlan(
484 client_plan.phase().tensorflow_spec(), client_plan.graph(),
485 client_plan.tensorflow_config_proto(), std::move(inputs), *output_names);
486 PlanResultAndCheckpointFile result(std::move(plan_result));
487 result.checkpoint_file = checkpoint_output_filename;
488 return result;
489 #else
490 return PlanResultAndCheckpointFile(
491 engine::PlanResult(engine::PlanOutcome::kTensorflowError,
492 absl::InternalError("No plan engine enabled")));
493 #endif
494 }
RunPlanWithExampleQuerySpec(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_output_filename)495 PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
496 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
497 OpStatsLogger* opstats_logger, const Flags* flags,
498 const ClientOnlyPlan& client_plan,
499 const std::string& checkpoint_output_filename) {
500 if (!client_plan.phase().has_example_query_spec()) {
501 return PlanResultAndCheckpointFile(engine::PlanResult(
502 engine::PlanOutcome::kInvalidArgument,
503 absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
504 }
505 if (!flags->enable_example_query_plan_engine()) {
506 // Example query plan received while the flag is off.
507 return PlanResultAndCheckpointFile(engine::PlanResult(
508 engine::PlanOutcome::kInvalidArgument,
509 absl::InvalidArgumentError(
510 "Example query plan received while the flag is off")));
511 }
512 if (!client_plan.phase().has_federated_example_query()) {
513 return PlanResultAndCheckpointFile(engine::PlanResult(
514 engine::PlanOutcome::kInvalidArgument,
515 absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
516 }
517 for (const auto& example_query :
518 client_plan.phase().example_query_spec().example_queries()) {
519 for (auto const& [vector_name, spec] :
520 example_query.output_vector_specs()) {
521 const auto& aggregations =
522 client_plan.phase().federated_example_query().aggregations();
523 if ((aggregations.find(vector_name) == aggregations.end()) ||
524 !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
525 return PlanResultAndCheckpointFile(engine::PlanResult(
526 engine::PlanOutcome::kInvalidArgument,
527 absl::InvalidArgumentError("Output vector is missing in "
528 "AggregationConfig, or has unsupported "
529 "aggregation type.")));
530 }
531 }
532 }
533 engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
534 opstats_logger);
535 engine::PlanResult plan_result = plan_engine.RunPlan(
536 client_plan.phase().example_query_spec(), checkpoint_output_filename);
537 PlanResultAndCheckpointFile result(std::move(plan_result));
538 result.checkpoint_file = checkpoint_output_filename;
539 return result;
540 }
LogEligibilityEvalComputationOutcome(PhaseLogger & phase_logger,engine::PlanResult plan_result,const absl::Status & eligibility_info_parsing_status,absl::Time run_plan_start_time,absl::Time reference_time)541 void LogEligibilityEvalComputationOutcome(
542 PhaseLogger& phase_logger, engine::PlanResult plan_result,
543 const absl::Status& eligibility_info_parsing_status,
544 absl::Time run_plan_start_time, absl::Time reference_time) {
545 switch (plan_result.outcome) {
546 case engine::PlanOutcome::kSuccess: {
547 if (eligibility_info_parsing_status.ok()) {
548 phase_logger.LogEligibilityEvalComputationCompleted(
549 plan_result.example_stats, run_plan_start_time, reference_time);
550 } else {
551 phase_logger.LogEligibilityEvalComputationTensorflowError(
552 eligibility_info_parsing_status, plan_result.example_stats,
553 run_plan_start_time, reference_time);
554 FCP_LOG(ERROR) << eligibility_info_parsing_status.message();
555 }
556 break;
557 }
558 case engine::PlanOutcome::kInterrupted:
559 phase_logger.LogEligibilityEvalComputationInterrupted(
560 plan_result.original_status, plan_result.example_stats,
561 run_plan_start_time, reference_time);
562 break;
563 case engine::PlanOutcome::kInvalidArgument:
564 phase_logger.LogEligibilityEvalComputationInvalidArgument(
565 plan_result.original_status, plan_result.example_stats,
566 run_plan_start_time);
567 break;
568 case engine::PlanOutcome::kTensorflowError:
569 phase_logger.LogEligibilityEvalComputationTensorflowError(
570 plan_result.original_status, plan_result.example_stats,
571 run_plan_start_time, reference_time);
572 break;
573 case engine::PlanOutcome::kExampleIteratorError:
574 phase_logger.LogEligibilityEvalComputationExampleIteratorError(
575 plan_result.original_status, plan_result.example_stats,
576 run_plan_start_time);
577 break;
578 }
579 }
LogComputationOutcome(const engine::PlanResult & plan_result,absl::Status computation_results_parsing_status,PhaseLogger & phase_logger,const NetworkStats & network_stats,absl::Time run_plan_start_time,absl::Time reference_time)580 void LogComputationOutcome(const engine::PlanResult& plan_result,
581 absl::Status computation_results_parsing_status,
582 PhaseLogger& phase_logger,
583 const NetworkStats& network_stats,
584 absl::Time run_plan_start_time,
585 absl::Time reference_time) {
586 switch (plan_result.outcome) {
587 case engine::PlanOutcome::kSuccess: {
588 if (computation_results_parsing_status.ok()) {
589 phase_logger.LogComputationCompleted(plan_result.example_stats,
590 network_stats, run_plan_start_time,
591 reference_time);
592 } else {
593 phase_logger.LogComputationTensorflowError(
594 computation_results_parsing_status, plan_result.example_stats,
595 network_stats, run_plan_start_time, reference_time);
596 }
597 break;
598 }
599 case engine::PlanOutcome::kInterrupted:
600 phase_logger.LogComputationInterrupted(
601 plan_result.original_status, plan_result.example_stats, network_stats,
602 run_plan_start_time, reference_time);
603 break;
604 case engine::PlanOutcome::kInvalidArgument:
605 phase_logger.LogComputationInvalidArgument(
606 plan_result.original_status, plan_result.example_stats, network_stats,
607 run_plan_start_time);
608 break;
609 case engine::PlanOutcome::kTensorflowError:
610 phase_logger.LogComputationTensorflowError(
611 plan_result.original_status, plan_result.example_stats, network_stats,
612 run_plan_start_time, reference_time);
613 break;
614 case engine::PlanOutcome::kExampleIteratorError:
615 phase_logger.LogComputationExampleIteratorError(
616 plan_result.original_status, plan_result.example_stats, network_stats,
617 run_plan_start_time);
618 break;
619 }
620 }
LogResultUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_result_upload,absl::Time reference_time)621 void LogResultUploadStatus(PhaseLogger& phase_logger, absl::Status result,
622 const NetworkStats& network_stats,
623 absl::Time time_before_result_upload,
624 absl::Time reference_time) {
625 if (result.ok()) {
626 phase_logger.LogResultUploadCompleted(
627 network_stats, time_before_result_upload, reference_time);
628 } else {
629 auto message =
630 absl::StrCat("Error reporting results: code: ", result.code(),
631 ", message: ", result.message());
632 FCP_LOG(INFO) << message;
633 if (result.code() == absl::StatusCode::kAborted) {
634 phase_logger.LogResultUploadServerAborted(
635 result, network_stats, time_before_result_upload, reference_time);
636 } else if (result.code() == absl::StatusCode::kCancelled) {
637 phase_logger.LogResultUploadClientInterrupted(
638 result, network_stats, time_before_result_upload, reference_time);
639 } else {
640 phase_logger.LogResultUploadIOError(
641 result, network_stats, time_before_result_upload, reference_time);
642 }
643 }
644 }
LogFailureUploadStatus(PhaseLogger & phase_logger,absl::Status result,const NetworkStats & network_stats,absl::Time time_before_failure_upload,absl::Time reference_time)645 void LogFailureUploadStatus(PhaseLogger& phase_logger, absl::Status result,
646 const NetworkStats& network_stats,
647 absl::Time time_before_failure_upload,
648 absl::Time reference_time) {
649 if (result.ok()) {
650 phase_logger.LogFailureUploadCompleted(
651 network_stats, time_before_failure_upload, reference_time);
652 } else {
653 auto message = absl::StrCat("Error reporting computation failure: code: ",
654 result.code(), ", message: ", result.message());
655 FCP_LOG(INFO) << message;
656 if (result.code() == absl::StatusCode::kAborted) {
657 phase_logger.LogFailureUploadServerAborted(
658 result, network_stats, time_before_failure_upload, reference_time);
659 } else if (result.code() == absl::StatusCode::kCancelled) {
660 phase_logger.LogFailureUploadClientInterrupted(
661 result, network_stats, time_before_failure_upload, reference_time);
662 } else {
663 phase_logger.LogFailureUploadIOError(
664 result, network_stats, time_before_failure_upload, reference_time);
665 }
666 }
667 }
ReportPlanResult(FederatedProtocol * federated_protocol,PhaseLogger & phase_logger,absl::StatusOr<ComputationResults> computation_results,absl::Time run_plan_start_time,absl::Time reference_time)668 absl::Status ReportPlanResult(
669 FederatedProtocol* federated_protocol, PhaseLogger& phase_logger,
670 absl::StatusOr<ComputationResults> computation_results,
671 absl::Time run_plan_start_time, absl::Time reference_time) {
672 const absl::Time before_report_time = absl::Now();
673 // Note that the FederatedSelectManager shouldn't be active anymore during
674 // the reporting of results, so we don't bother passing it to
675 // GetNetworkStatsSince.
676 //
677 // We must return only stats that cover the report phase for the log events
678 // below.
679 const NetworkStats before_report_stats =
680 GetCumulativeNetworkStats(federated_protocol,
681 /*fedselect_manager=*/nullptr);
682 absl::Status result = absl::InternalError("");
683 if (computation_results.ok()) {
684 FCP_RETURN_IF_ERROR(phase_logger.LogResultUploadStarted());
685 result = federated_protocol->ReportCompleted(
686 std::move(*computation_results),
687 /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
688 LogResultUploadStatus(phase_logger, result,
689 GetNetworkStatsSince(federated_protocol,
690 /*fedselect_manager=*/nullptr,
691 before_report_stats),
692 before_report_time, reference_time);
693 } else {
694 FCP_RETURN_IF_ERROR(phase_logger.LogFailureUploadStarted());
695 result = federated_protocol->ReportNotCompleted(
696 engine::PhaseOutcome::ERROR,
697 /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
698 LogFailureUploadStatus(phase_logger, result,
699 GetNetworkStatsSince(federated_protocol,
700 /*fedselect_manager=*/nullptr,
701 before_report_stats),
702 before_report_time, reference_time);
703 }
704 return result;
705 }
706 // Writes the given data to the stream, and returns true if successful and
707 // false if not.
WriteStringOrCordToFstream(std::fstream & stream,const std::variant<std::string,absl::Cord> & data)708 bool WriteStringOrCordToFstream(
709 std::fstream& stream, const std::variant<std::string, absl::Cord>& data) {
710 if (stream.fail()) {
711 return false;
712 }
713 if (std::holds_alternative<std::string>(data)) {
714 return (stream << std::get<std::string>(data)).good();
715 }
716 for (absl::string_view chunk : std::get<absl::Cord>(data).Chunks()) {
717 if (!(stream << chunk).good()) {
718 return false;
719 }
720 }
721 return true;
722 }
723 // Writes the given checkpoint data to a newly created temporary file.
724 // Returns the filename if successful, or an error if the file could not be
725 // created, or if writing to the file failed.
CreateInputCheckpointFile(Files * files,const std::variant<std::string,absl::Cord> & checkpoint)726 absl::StatusOr<std::string> CreateInputCheckpointFile(
727 Files* files, const std::variant<std::string, absl::Cord>& checkpoint) {
728 // Create the temporary checkpoint file.
729 // Deletion of the file is left to the caller / the Files implementation.
730 FCP_ASSIGN_OR_RETURN(absl::StatusOr<std::string> filename,
731 files->CreateTempFile("init", ".ckp"));
732 // Write the checkpoint data to the file.
733 std::fstream checkpoint_stream(*filename, std::ios_base::out);
734 if (!WriteStringOrCordToFstream(checkpoint_stream, checkpoint)) {
735 return absl::InvalidArgumentError("Failed to write to file");
736 }
737 checkpoint_stream.close();
738 return filename;
739 }
RunEligibilityEvalPlan(const FederatedProtocol::EligibilityEvalTask & eligibility_eval_task,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const absl::Time time_before_checkin,const absl::Time time_before_plan_download,const NetworkStats & network_stats)740 absl::StatusOr<std::optional<TaskEligibilityInfo>> RunEligibilityEvalPlan(
741 const FederatedProtocol::EligibilityEvalTask& eligibility_eval_task,
742 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
743 std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
744 LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
745 FederatedProtocol* federated_protocol,
746 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
747 const absl::Time reference_time, const absl::Time time_before_checkin,
748 const absl::Time time_before_plan_download,
749 const NetworkStats& network_stats) {
750 ClientOnlyPlan plan;
751 if (!ParseFromStringOrCord(plan, eligibility_eval_task.payloads.plan)) {
752 auto message = "Failed to parse received eligibility eval plan";
753 phase_logger.LogEligibilityEvalCheckinInvalidPayloadError(
754 message, network_stats, time_before_plan_download);
755 FCP_LOG(ERROR) << message;
756 return absl::InternalError(message);
757 }
758 absl::StatusOr<std::string> checkpoint_input_filename =
759 CreateInputCheckpointFile(files,
760 eligibility_eval_task.payloads.checkpoint);
761 if (!checkpoint_input_filename.ok()) {
762 auto status = checkpoint_input_filename.status();
763 auto message = absl::StrCat(
764 "Failed to create eligibility eval checkpoint input file: code: ",
765 status.code(), ", message: ", status.message());
766 phase_logger.LogEligibilityEvalCheckinIOError(status, network_stats,
767 time_before_plan_download);
768 FCP_LOG(ERROR) << message;
769 return absl::InternalError("");
770 }
771 phase_logger.LogEligibilityEvalCheckinCompleted(network_stats,
772 /*time_before_checkin=*/
773 time_before_checkin,
774 /*time_before_plan_download=*/
775 time_before_plan_download);
776 absl::Time run_plan_start_time = absl::Now();
777 phase_logger.LogEligibilityEvalComputationStarted();
778 engine::PlanResult plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
779 example_iterator_factories, should_abort, log_manager, opstats_logger,
780 flags, plan, *checkpoint_input_filename, timing_config,
781 run_plan_start_time, reference_time);
782 absl::StatusOr<TaskEligibilityInfo> task_eligibility_info;
783 if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
784 task_eligibility_info =
785 ParseEligibilityEvalPlanOutput(plan_result.output_tensors);
786 }
787 LogEligibilityEvalComputationOutcome(phase_logger, std::move(plan_result),
788 task_eligibility_info.status(),
789 run_plan_start_time, reference_time);
790 return task_eligibility_info;
791 }
792 struct EligibilityEvalResult {
793 std::optional<TaskEligibilityInfo> task_eligibility_info;
794 std::vector<std::string> task_names_for_multiple_task_assignments;
795 };
796 // Create an EligibilityEvalResult from a TaskEligibilityInfo and
797 // PopulationEligibilitySpec. If both population_spec and
798 // task_eligibility_info are present, the returned EligibilityEvalResult will
799 // contain a TaskEligibilityInfo which only contains the tasks for single task
800 // assignment, and a vector of task names for multiple task assignment.
CreateEligibilityEvalResult(const std::optional<TaskEligibilityInfo> & task_eligibility_info,const std::optional<PopulationEligibilitySpec> & population_spec)801 EligibilityEvalResult CreateEligibilityEvalResult(
802 const std::optional<TaskEligibilityInfo>& task_eligibility_info,
803 const std::optional<PopulationEligibilitySpec>& population_spec) {
804 EligibilityEvalResult result;
805 if (population_spec.has_value() && task_eligibility_info.has_value()) {
806 absl::flat_hash_set<std::string> task_names_for_multiple_task_assignments;
807 for (const auto& task_info : population_spec.value().task_info()) {
808 if (task_info.task_assignment_mode() ==
809 PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE) {
810 task_names_for_multiple_task_assignments.insert(task_info.task_name());
811 }
812 }
813 TaskEligibilityInfo single_task_assignment_eligibility_info;
814 single_task_assignment_eligibility_info.set_version(
815 task_eligibility_info.value().version());
816 for (const auto& task_weight :
817 task_eligibility_info.value().task_weights()) {
818 if (task_names_for_multiple_task_assignments.contains(
819 task_weight.task_name())) {
820 result.task_names_for_multiple_task_assignments.push_back(
821 task_weight.task_name());
822 } else {
823 *single_task_assignment_eligibility_info.mutable_task_weights()->Add() =
824 task_weight;
825 }
826 }
827 result.task_eligibility_info = single_task_assignment_eligibility_info;
828 } else {
829 result.task_eligibility_info = task_eligibility_info;
830 }
831 return result;
832 }
833 // Issues an eligibility eval checkin request and executes the eligibility
834 // eval task if the server returns one.
835 //
836 // This function modifies the FLRunnerResult with values received over the
837 // course of the eligibility eval protocol interaction.
838 //
839 // Returns:
840 // - the TaskEligibilityInfo produced by the eligibility eval task, if the
841 // server provided an eligibility eval task to run.
842 // - an std::nullopt if the server indicated that there is no eligibility eval
843 // task configured for the population.
844 // - an INTERNAL error if the server rejects the client or another error
845 // occurs
846 // that should abort the training run. The error will already have been
847 // logged appropriately.
IssueEligibilityEvalCheckinAndRunPlan(std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,PhaseLogger & phase_logger,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,FLRunnerResult & fl_runner_result)848 absl::StatusOr<EligibilityEvalResult> IssueEligibilityEvalCheckinAndRunPlan(
849 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
850 std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
851 LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
852 FederatedProtocol* federated_protocol,
853 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
854 const absl::Time reference_time, FLRunnerResult& fl_runner_result) {
855 const absl::Time time_before_checkin = absl::Now();
856 const NetworkStats network_stats_before_checkin =
857 GetCumulativeNetworkStats(federated_protocol,
858 /*fedselect_manager=*/nullptr);
859 // These fields will, after a successful checkin that resulted in an EET
860 // being received, contain the time at which the EET plan/checkpoint URIs
861 // were received (but not yet downloaded), as well as the cumulative network
862 // stats at that point, allowing us to separately calculate how long it took
863 // to then download the actual payloads.
864 absl::Time time_before_plan_download = time_before_checkin;
865 NetworkStats network_stats_before_plan_download =
866 network_stats_before_checkin;
867 // Log that we are about to check in with the server.
868 phase_logger.LogEligibilityEvalCheckinStarted();
869 // Issue the eligibility eval checkin request (providing a callback that
870 // will be called when an EET is assigned to the task but before its
871 // plan/checkpoint URIs have actually been downloaded).
872 bool plan_uris_received_callback_called = false;
873 std::function<void(const FederatedProtocol::EligibilityEvalTask&)>
874 plan_uris_received_callback =
875 [&time_before_plan_download, &network_stats_before_plan_download,
876 &time_before_checkin, &network_stats_before_checkin,
877 &federated_protocol, &phase_logger,
878 &plan_uris_received_callback_called](
879 const FederatedProtocol::EligibilityEvalTask& task) {
880 // When the plan URIs have been received, we already know the name
881 // of the task we have been assigned, so let's tell the
882 // PhaseLogger.
883 phase_logger.SetModelIdentifier(task.execution_id);
884 // We also should log a corresponding log event.
885 phase_logger.LogEligibilityEvalCheckinPlanUriReceived(
886 GetNetworkStatsSince(federated_protocol,
887 /*fedselect_manager=*/nullptr,
888 network_stats_before_checkin),
889 time_before_checkin);
890 // And we must take a snapshot of the current time & network
891 // stats, so we can distinguish between the duration/network stats
892 // incurred for the checkin request vs. the actual downloading of
893 // the plan/checkpoint resources.
894 time_before_plan_download = absl::Now();
895 network_stats_before_plan_download =
896 GetCumulativeNetworkStats(federated_protocol,
897 /*fedselect_manager=*/nullptr);
898 plan_uris_received_callback_called = true;
899 };
900 absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
901 eligibility_checkin_result = federated_protocol->EligibilityEvalCheckin(
902 plan_uris_received_callback);
903 UpdateRetryWindowAndNetworkStats(*federated_protocol,
904 /*fedselect_manager=*/nullptr, phase_logger,
905 fl_runner_result);
906 // It's a bit unfortunate that we have to inspect the checkin_result and
907 // extract the model identifier here rather than further down the function,
908 // but this ensures that the histograms below will have the right model
909 // identifier attached (and we want to also emit the histograms even if we
910 // have failed/rejected checkin outcomes).
911 if (eligibility_checkin_result.ok() &&
912 std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
913 *eligibility_checkin_result)) {
914 // Make sure that if we received an EligibilityEvalTask, then the callback
915 // should have already been called by this point by the protocol (ensuring
916 // that SetModelIdentifier has been called etc.).
917 FCP_CHECK(plan_uris_received_callback_called);
918 }
919 if (!eligibility_checkin_result.ok()) {
920 auto status = eligibility_checkin_result.status();
921 auto message = absl::StrCat("Error during eligibility eval checkin: code: ",
922 status.code(), ", message: ", status.message());
923 if (status.code() == absl::StatusCode::kAborted) {
924 phase_logger.LogEligibilityEvalCheckinServerAborted(
925 status,
926 GetNetworkStatsSince(federated_protocol,
927 /*fedselect_manager=*/nullptr,
928 network_stats_before_plan_download),
929 time_before_plan_download);
930 } else if (status.code() == absl::StatusCode::kCancelled) {
931 phase_logger.LogEligibilityEvalCheckinClientInterrupted(
932 status,
933 GetNetworkStatsSince(federated_protocol,
934 /*fedselect_manager=*/nullptr,
935 network_stats_before_plan_download),
936 time_before_plan_download);
937 } else if (!status.ok()) {
938 phase_logger.LogEligibilityEvalCheckinIOError(
939 status,
940 GetNetworkStatsSince(federated_protocol,
941 /*fedselect_manager=*/nullptr,
942 network_stats_before_plan_download),
943 time_before_plan_download);
944 }
945 FCP_LOG(INFO) << message;
946 return absl::InternalError("");
947 }
948 EligibilityEvalResult result;
949 if (std::holds_alternative<FederatedProtocol::Rejection>(
950 *eligibility_checkin_result)) {
951 phase_logger.LogEligibilityEvalCheckinTurnedAway(
952 GetNetworkStatsSince(federated_protocol,
953 /*fedselect_manager=*/nullptr,
954 network_stats_before_checkin),
955 time_before_checkin);
956 // If the server explicitly rejected our request, then we must abort and
957 // we must not proceed to the "checkin" phase below.
958 FCP_LOG(INFO) << "Device rejected by server during eligibility eval "
959 "checkin; aborting";
960 return absl::InternalError("");
961 } else if (std::holds_alternative<FederatedProtocol::EligibilityEvalDisabled>(
962 *eligibility_checkin_result)) {
963 phase_logger.LogEligibilityEvalNotConfigured(
964 GetNetworkStatsSince(federated_protocol,
965 /*fedselect_manager=*/nullptr,
966 network_stats_before_checkin),
967 time_before_checkin);
968 // If the server indicates that no eligibility eval task is configured for
969 // the population then there is nothing more to do. We simply proceed to
970 // the "checkin" phase below without providing it a TaskEligibilityInfo
971 // proto.
972 result.task_eligibility_info = std::nullopt;
973 return result;
974 }
975 auto eligibility_eval_task =
976 absl::get<FederatedProtocol::EligibilityEvalTask>(
977 *eligibility_checkin_result);
978 // Parse and run the eligibility eval task if the server returned one.
979 // Now we have a EligibilityEvalTask, if an error happens, we will report to
980 // the server via the ReportEligibilityEvalError.
981 absl::StatusOr<std::optional<TaskEligibilityInfo>> task_eligibility_info =
982 RunEligibilityEvalPlan(
983 eligibility_eval_task, example_iterator_factories, should_abort,
984 phase_logger, files, log_manager, opstats_logger, flags,
985 federated_protocol, timing_config, reference_time,
986 /*time_before_checkin=*/time_before_checkin,
987 /*time_before_plan_download=*/time_before_plan_download,
988 GetNetworkStatsSince(federated_protocol,
989 /*fedselect_manager=*/nullptr,
990 network_stats_before_plan_download));
991 if (!task_eligibility_info.ok()) {
992 // Note that none of the PhaseLogger methods will reflect the very little
993 // amount of network usage the will be incurred by this protocol request.
994 // We consider this to be OK to keep things simple, and because this
995 // should use such a limited amount of network bandwidth. Do note that the
996 // network usage *will* be correctly reported in the OpStats database.
997 federated_protocol->ReportEligibilityEvalError(
998 absl::Status(task_eligibility_info.status().code(),
999 "Failed to compute eligibility info"));
1000 UpdateRetryWindowAndNetworkStats(*federated_protocol,
1001 /*fedselect_manager=*/nullptr,
1002 phase_logger, fl_runner_result);
1003 return task_eligibility_info.status();
1004 }
1005 return CreateEligibilityEvalResult(
1006 *task_eligibility_info,
1007 eligibility_eval_task.population_eligibility_spec);
1008 }
1009 struct CheckinResult {
1010 std::string task_name;
1011 ClientOnlyPlan plan;
1012 int32_t minimum_clients_in_server_visible_aggregate;
1013 std::string checkpoint_input_filename;
1014 std::string computation_id;
1015 std::string federated_select_uri_template;
1016 };
IssueCheckin(PhaseLogger & phase_logger,LogManager * log_manager,Files * files,FederatedProtocol * federated_protocol,std::optional<TaskEligibilityInfo> task_eligibility_info,absl::Time reference_time,const std::string & population_name,FLRunnerResult & fl_runner_result,const Flags * flags)1017 absl::StatusOr<CheckinResult> IssueCheckin(
1018 PhaseLogger& phase_logger, LogManager* log_manager, Files* files,
1019 FederatedProtocol* federated_protocol,
1020 std::optional<TaskEligibilityInfo> task_eligibility_info,
1021 absl::Time reference_time, const std::string& population_name,
1022 FLRunnerResult& fl_runner_result, const Flags* flags) {
1023 absl::Time time_before_checkin = absl::Now();
1024 // We must return only stats that cover the check in phase for the log
1025 // events below.
1026 const NetworkStats network_stats_before_checkin =
1027 GetCumulativeNetworkStats(federated_protocol,
1028 /*fedselect_manager=*/nullptr);
1029 // These fields will, after a successful checkin that resulted in a task
1030 // being assigned, contain the time at which the task plan/checkpoint URIs
1031 // were received (but not yet downloaded), as well as the cumulative network
1032 // stats at that point, allowing us to separately calculate how long it took
1033 // to then download the actual payloads.
1034 absl::Time time_before_plan_download = time_before_checkin;
1035 NetworkStats network_stats_before_plan_download =
1036 network_stats_before_checkin;
1037 // Clear the model identifier before check-in, to ensure that the any prior
1038 // eligibility eval task name isn't used any longer.
1039 phase_logger.SetModelIdentifier("");
1040 phase_logger.LogCheckinStarted();
1041 std::string task_name;
1042 // Issue the checkin request (providing a callback that will be called when
1043 // an EET is assigned to the task but before its plan/checkpoint URIs have
1044 // actually been downloaded).
1045 bool plan_uris_received_callback_called = false;
1046 std::function<void(const FederatedProtocol::TaskAssignment&)>
1047 plan_uris_received_callback =
1048 [&time_before_plan_download, &network_stats_before_plan_download,
1049 &time_before_checkin, &network_stats_before_checkin, &task_name,
1050 &federated_protocol, &population_name, &log_manager, &phase_logger,
1051 &plan_uris_received_callback_called](
1052 const FederatedProtocol::TaskAssignment& task_assignment) {
1053 // When the plan URIs have been received, we already know the name
1054 // of the task we have been assigned, so let's tell the
1055 // PhaseLogger.
1056 auto model_identifier = task_assignment.aggregation_session_id;
1057 phase_logger.SetModelIdentifier(model_identifier);
1058 // We also should log a corresponding log event.
1059 task_name = ExtractTaskNameFromAggregationSessionId(
1060 model_identifier, population_name, *log_manager);
1061 phase_logger.LogCheckinPlanUriReceived(
1062 task_name,
1063 GetNetworkStatsSince(federated_protocol,
1064 /*fedselect_manager=*/nullptr,
1065 network_stats_before_checkin),
1066 time_before_checkin);
1067 // And we must take a snapshot of the current time & network
1068 // stats, so we can distinguish between the duration/network stats
1069 // incurred for the checkin request vs. the actual downloading of
1070 // the plan/checkpoint resources.
1071 time_before_plan_download = absl::Now();
1072 network_stats_before_plan_download = GetCumulativeNetworkStats(
1073 federated_protocol, /*fedselect_manager=*/nullptr);
1074 plan_uris_received_callback_called = true;
1075 };
1076 absl::StatusOr<FederatedProtocol::CheckinResult> checkin_result =
1077 federated_protocol->Checkin(task_eligibility_info,
1078 plan_uris_received_callback);
1079 UpdateRetryWindowAndNetworkStats(*federated_protocol,
1080 /*fedselect_manager=*/nullptr, phase_logger,
1081 fl_runner_result);
1082 // It's a bit unfortunate that we have to inspect the checkin_result and
1083 // extract the model identifier here rather than further down the function,
1084 // but this ensures that the histograms below will have the right model
1085 // identifier attached (and we want to also emit the histograms even if we
1086 // have failed/rejected checkin outcomes).
1087 if (checkin_result.ok() &&
1088 std::holds_alternative<FederatedProtocol::TaskAssignment>(
1089 *checkin_result)) {
1090 // Make sure that if we received a TaskAssignment, then the callback
1091 // should have already been called by this point by the protocol (ensuring
1092 // that SetModelIdentifier has been called etc.).
1093 FCP_CHECK(plan_uris_received_callback_called);
1094 }
1095 if (!checkin_result.ok()) {
1096 auto status = checkin_result.status();
1097 auto message = absl::StrCat("Error during checkin: code: ", status.code(),
1098 ", message: ", status.message());
1099 if (status.code() == absl::StatusCode::kAborted) {
1100 phase_logger.LogCheckinServerAborted(
1101 status,
1102 GetNetworkStatsSince(federated_protocol,
1103 /*fedselect_manager=*/nullptr,
1104 network_stats_before_plan_download),
1105 time_before_plan_download, reference_time);
1106 } else if (status.code() == absl::StatusCode::kCancelled) {
1107 phase_logger.LogCheckinClientInterrupted(
1108 status,
1109 GetNetworkStatsSince(federated_protocol,
1110 /*fedselect_manager=*/nullptr,
1111 network_stats_before_plan_download),
1112 time_before_plan_download, reference_time);
1113 } else if (!status.ok()) {
1114 phase_logger.LogCheckinIOError(
1115 status,
1116 GetNetworkStatsSince(federated_protocol,
1117 /*fedselect_manager=*/nullptr,
1118 network_stats_before_plan_download),
1119 time_before_plan_download, reference_time);
1120 }
1121 FCP_LOG(INFO) << message;
1122 return status;
1123 }
1124 // Server rejected us? Return the fl_runner_results as-is.
1125 if (std::holds_alternative<FederatedProtocol::Rejection>(*checkin_result)) {
1126 phase_logger.LogCheckinTurnedAway(
1127 GetNetworkStatsSince(federated_protocol,
1128 /*fedselect_manager=*/nullptr,
1129 network_stats_before_checkin),
1130 time_before_checkin, reference_time);
1131 FCP_LOG(INFO) << "Device rejected by server during checkin; aborting";
1132 return absl::InternalError("Device rejected by server.");
1133 }
1134 auto task_assignment =
1135 absl::get<FederatedProtocol::TaskAssignment>(*checkin_result);
1136 ClientOnlyPlan plan;
1137 auto plan_bytes = task_assignment.payloads.plan;
1138 if (!ParseFromStringOrCord(plan, plan_bytes)) {
1139 auto message = "Failed to parse received plan";
1140 phase_logger.LogCheckinInvalidPayload(
1141 message,
1142 GetNetworkStatsSince(federated_protocol,
1143 /*fedselect_manager=*/nullptr,
1144 network_stats_before_plan_download),
1145 time_before_plan_download, reference_time);
1146 FCP_LOG(ERROR) << message;
1147 return absl::InternalError("");
1148 }
1149 std::string computation_id;
1150 if (flags->enable_computation_id()) {
1151 computation_id = ComputeSHA256FromStringOrCord(plan_bytes);
1152 }
1153 int32_t minimum_clients_in_server_visible_aggregate = 0;
1154 if (task_assignment.sec_agg_info.has_value()) {
1155 auto minimum_number_of_participants =
1156 plan.phase().minimum_number_of_participants();
1157 if (task_assignment.sec_agg_info->expected_number_of_clients <
1158 minimum_number_of_participants) {
1159 return absl::InternalError(
1160 "expectedNumberOfClients was less than Plan's "
1161 "minimumNumberOfParticipants.");
1162 }
1163 minimum_clients_in_server_visible_aggregate =
1164 task_assignment.sec_agg_info
1165 ->minimum_clients_in_server_visible_aggregate;
1166 }
1167 absl::StatusOr<std::string> checkpoint_input_filename = "";
1168 // Example query plan does not have an input checkpoint.
1169 if (!plan.phase().has_example_query_spec()) {
1170 checkpoint_input_filename =
1171 CreateInputCheckpointFile(files, task_assignment.payloads.checkpoint);
1172 if (!checkpoint_input_filename.ok()) {
1173 auto status = checkpoint_input_filename.status();
1174 auto message = absl::StrCat(
1175 "Failed to create checkpoint input file: code: ", status.code(),
1176 ", message: ", status.message());
1177 phase_logger.LogCheckinIOError(
1178 status,
1179 GetNetworkStatsSince(federated_protocol,
1180 /*fedselect_manager=*/nullptr,
1181 network_stats_before_plan_download),
1182 time_before_plan_download, reference_time);
1183 FCP_LOG(ERROR) << message;
1184 return status;
1185 }
1186 }
1187 phase_logger.LogCheckinCompleted(
1188 task_name,
1189 GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr,
1190 network_stats_before_plan_download),
1191 /*time_before_checkin=*/time_before_checkin,
1192 /*time_before_plan_download=*/time_before_plan_download, reference_time);
1193 return CheckinResult{
1194 .task_name = std::move(task_name),
1195 .plan = std::move(plan),
1196 .minimum_clients_in_server_visible_aggregate =
1197 minimum_clients_in_server_visible_aggregate,
1198 .checkpoint_input_filename = std::move(*checkpoint_input_filename),
1199 .computation_id = std::move(computation_id),
1200 .federated_select_uri_template =
1201 task_assignment.federated_select_uri_template};
1202 }
1203 } // namespace
RunFederatedComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,const std::string & session_name,const std::string & population_name,const std::string & retry_token,const std::string & client_version,const std::string & attestation_measurement)1204 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1205 SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1206 Files* files, LogManager* log_manager, const Flags* flags,
1207 const std::string& federated_service_uri, const std::string& api_key,
1208 const std::string& test_cert_path, const std::string& session_name,
1209 const std::string& population_name, const std::string& retry_token,
1210 const std::string& client_version,
1211 const std::string& attestation_measurement) {
1212 auto opstats_logger =
1213 engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1214 session_name, population_name);
1215 absl::Time reference_time = absl::Now();
1216 FLRunnerResult fl_runner_result;
1217 fcp::client::InterruptibleRunner::TimingConfig timing_config = {
1218 .polling_period =
1219 absl::Milliseconds(flags->condition_polling_period_millis()),
1220 .graceful_shutdown_period = absl::Milliseconds(
1221 flags->tf_execution_teardown_grace_period_millis()),
1222 .extended_shutdown_period = absl::Milliseconds(
1223 flags->tf_execution_teardown_extended_period_millis()),
1224 };
1225 auto should_abort_protocol_callback = [&env_deps, &timing_config]() -> bool {
1226 // Return the Status if failed, or the negated value if successful.
1227 return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1228 };
1229 PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1230 log_manager, flags);
1231 // If there was an error initializing OpStats, opstats_logger will be a
1232 // no-op implementation and execution will be allowed to continue.
1233 if (!opstats_logger->GetInitStatus().ok()) {
1234 // This will only happen if OpStats is enabled and there was an error in
1235 // initialization.
1236 phase_logger.LogNonfatalInitializationError(
1237 opstats_logger->GetInitStatus());
1238 }
1239 Clock* clock = Clock::RealClock();
1240 std::unique_ptr<cache::ResourceCache> resource_cache;
1241 // if (flags->max_resource_cache_size_bytes() > 0) {
1242 // // Anything that goes wrong in FileBackedResourceCache::Create is a
1243 // // programmer error.
1244 // absl::StatusOr<std::unique_ptr<cache::ResourceCache>>
1245 // resource_cache_internal = cache::FileBackedResourceCache::Create(
1246 // env_deps->GetBaseDir(), env_deps->GetCacheDir(), log_manager,
1247 // clock, flags->max_resource_cache_size_bytes());
1248 // if (!resource_cache_internal.ok()) {
1249 // auto resource_init_failed_status = absl::Status(
1250 // resource_cache_internal.status().code(),
1251 // absl::StrCat("Failed to initialize FileBackedResourceCache: ",
1252 // resource_cache_internal.status().ToString()));
1253 // if (flags->resource_cache_initialization_error_is_fatal()) {
1254 // phase_logger.LogFatalInitializationError(resource_init_failed_status);
1255 // return resource_init_failed_status;
1256 // }
1257 // // We log an error but otherwise proceed as if the cache was disabled.
1258 // phase_logger.LogNonfatalInitializationError(resource_init_failed_status);
1259 // } else {
1260 // resource_cache = std::move(*resource_cache_internal);
1261 // }
1262 // }
1263 std::unique_ptr<::fcp::client::http::HttpClient> http_client =
1264 flags->enable_grpc_with_http_resource_support() ||
1265 flags->use_http_federated_compute_protocol()
1266 ? env_deps->CreateHttpClient()
1267 : nullptr;
1268 std::unique_ptr<FederatedProtocol> federated_protocol;
1269 if (flags->use_http_federated_compute_protocol()) {
1270 log_manager->LogDiag(ProdDiagCode::HTTP_FEDERATED_PROTOCOL_USED);
1271 // Verify the entry point uri starts with "https://" or
1272 // "http://localhost". Note "http://localhost" is allowed for testing
1273 // purpose.
1274 if (!(absl::StartsWith(federated_service_uri, "https://") ||
1275 absl::StartsWith(federated_service_uri, "http://localhost"))) {
1276 return absl::InvalidArgumentError("The entry point uri is invalid.");
1277 }
1278 federated_protocol = std::make_unique<http::HttpFederatedProtocol>(
1279 clock, log_manager, flags, http_client.get(),
1280 std::make_unique<SecAggRunnerFactoryImpl>(),
1281 event_publisher->secagg_event_publisher(), federated_service_uri,
1282 api_key, population_name, retry_token, client_version,
1283 attestation_measurement, should_abort_protocol_callback, absl::BitGen(),
1284 timing_config, resource_cache.get());
1285 } else {
1286 #ifdef FCP_CLIENT_SUPPORT_GRPC
1287 // Check in with the server to either retrieve a plan + initial
1288 // checkpoint, or get rejected with a RetryWindow.
1289 auto grpc_channel_deadline = flags->grpc_channel_deadline_seconds();
1290 if (grpc_channel_deadline <= 0) {
1291 grpc_channel_deadline = 600;
1292 FCP_LOG(INFO) << "Using default channel deadline of "
1293 << grpc_channel_deadline << " seconds.";
1294 }
1295 federated_protocol = std::make_unique<GrpcFederatedProtocol>(
1296 event_publisher, log_manager,
1297 std::make_unique<SecAggRunnerFactoryImpl>(), flags, http_client.get(),
1298 federated_service_uri, api_key, test_cert_path, population_name,
1299 retry_token, client_version, attestation_measurement,
1300 should_abort_protocol_callback, timing_config, grpc_channel_deadline,
1301 resource_cache.get());
1302 #else
1303 return absl::InternalError("No FederatedProtocol enabled.");
1304 #endif
1305 }
1306 std::unique_ptr<FederatedSelectManager> federated_select_manager;
1307 if (http_client != nullptr && flags->enable_federated_select()) {
1308 federated_select_manager = std::make_unique<HttpFederatedSelectManager>(
1309 log_manager, files, http_client.get(), should_abort_protocol_callback,
1310 timing_config);
1311 } else {
1312 federated_select_manager =
1313 std::make_unique<DisabledFederatedSelectManager>(log_manager);
1314 }
1315 return RunFederatedComputation(env_deps, phase_logger, event_publisher, files,
1316 log_manager, opstats_logger.get(), flags,
1317 federated_protocol.get(),
1318 federated_select_manager.get(), timing_config,
1319 reference_time, session_name, population_name);
1320 }
RunFederatedComputation(SimpleTaskEnvironment * env_deps,PhaseLogger & phase_logger,EventPublisher * event_publisher,Files * files,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,FederatedProtocol * federated_protocol,FederatedSelectManager * fedselect_manager,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time reference_time,const std::string & session_name,const std::string & population_name)1321 absl::StatusOr<FLRunnerResult> RunFederatedComputation(
1322 SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger,
1323 EventPublisher* event_publisher, Files* files, LogManager* log_manager,
1324 OpStatsLogger* opstats_logger, const Flags* flags,
1325 FederatedProtocol* federated_protocol,
1326 FederatedSelectManager* fedselect_manager,
1327 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1328 const absl::Time reference_time, const std::string& session_name,
1329 const std::string& population_name) {
1330 SelectorContext federated_selector_context;
1331 federated_selector_context.mutable_computation_properties()->set_session_name(
1332 session_name);
1333 FederatedComputation federated_computation;
1334 federated_computation.set_population_name(population_name);
1335 *federated_selector_context.mutable_computation_properties()
1336 ->mutable_federated() = federated_computation;
1337 SelectorContext eligibility_selector_context;
1338 eligibility_selector_context.mutable_computation_properties()
1339 ->set_session_name(session_name);
1340 EligibilityEvalComputation eligibility_eval_computation;
1341 eligibility_eval_computation.set_population_name(population_name);
1342 *eligibility_selector_context.mutable_computation_properties()
1343 ->mutable_eligibility_eval() = eligibility_eval_computation;
1344 // Construct a default FLRunnerResult that reflects an unsuccessful training
1345 // attempt and which uses RetryWindow corresponding to transient errors (if
1346 // the flag is on).
1347 // This is what will be returned if we have to bail early, before we've
1348 // received a RetryWindow from the server.
1349 FLRunnerResult fl_runner_result;
1350 fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
1351 // Before we even check whether we should abort right away, update the retry
1352 // window. That way we will use the most appropriate retry window we have
1353 // available (an implementation detail of FederatedProtocol, but generally a
1354 // 'transient error' retry window based on the provided flag values) in case
1355 // we do need to abort.
1356 UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1357 phase_logger, fl_runner_result);
1358 // Check if the device conditions allow for checking in with the server
1359 // and running a federated computation. If not, bail early with the
1360 // transient error retry window.
1361 std::function<bool()> should_abort = [env_deps, &timing_config]() {
1362 return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1363 };
1364 if (should_abort()) {
1365 std::string message =
1366 "Device conditions not satisfied, aborting federated computation";
1367 FCP_LOG(INFO) << message;
1368 phase_logger.LogTaskNotStarted(message);
1369 return fl_runner_result;
1370 }
1371 // Eligibility eval plans can use example iterators from the
1372 // SimpleTaskEnvironment and those reading the OpStats DB.
1373 opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1374 opstats_logger, log_manager,
1375 flags->opstats_last_successful_contribution_criteria());
1376 std::unique_ptr<engine::ExampleIteratorFactory>
1377 env_eligibility_example_iterator_factory =
1378 CreateSimpleTaskEnvironmentIteratorFactory(
1379 env_deps, eligibility_selector_context);
1380 std::vector<engine::ExampleIteratorFactory*>
1381 eligibility_example_iterator_factories{
1382 &opstats_example_iterator_factory,
1383 env_eligibility_example_iterator_factory.get()};
1384 // Note that this method will update fl_runner_result's fields with values
1385 // received over the course of the eligibility eval protocol interaction.
1386 absl::StatusOr<EligibilityEvalResult> eligibility_eval_result =
1387 IssueEligibilityEvalCheckinAndRunPlan(
1388 eligibility_example_iterator_factories, should_abort, phase_logger,
1389 files, log_manager, opstats_logger, flags, federated_protocol,
1390 timing_config, reference_time, fl_runner_result);
1391 if (!eligibility_eval_result.ok()) {
1392 return fl_runner_result;
1393 }
1394 auto checkin_result =
1395 IssueCheckin(phase_logger, log_manager, files, federated_protocol,
1396 std::move(eligibility_eval_result->task_eligibility_info),
1397 reference_time, population_name, fl_runner_result, flags);
1398 if (!checkin_result.ok()) {
1399 return fl_runner_result;
1400 }
1401 SelectorContext federated_selector_context_with_task_name =
1402 federated_selector_context;
1403 federated_selector_context_with_task_name.mutable_computation_properties()
1404 ->mutable_federated()
1405 ->set_task_name(checkin_result->task_name);
1406 if (flags->enable_computation_id()) {
1407 federated_selector_context_with_task_name.mutable_computation_properties()
1408 ->mutable_federated()
1409 ->set_computation_id(checkin_result->computation_id);
1410 }
1411 if (checkin_result->plan.phase().has_example_query_spec()) {
1412 federated_selector_context_with_task_name.mutable_computation_properties()
1413 ->set_example_iterator_output_format(
1414 ::fcp::client::QueryTimeComputationProperties::
1415 EXAMPLE_QUERY_RESULT);
1416 }
1417 // Include the last successful contribution timestamp in the
1418 // SelectorContext.
1419 const auto& opstats_db = opstats_logger->GetOpStatsDb();
1420 if (opstats_db != nullptr) {
1421 absl::StatusOr<opstats::OpStatsSequence> data = opstats_db->Read();
1422 if (data.ok()) {
1423 std::optional<google::protobuf::Timestamp>
1424 last_successful_contribution_time =
1425 opstats::GetLastSuccessfulContributionTime(
1426 *data, checkin_result->task_name);
1427 if (last_successful_contribution_time.has_value()) {
1428 *(federated_selector_context_with_task_name
1429 .mutable_computation_properties()
1430 ->mutable_federated()
1431 ->mutable_historical_context()
1432 ->mutable_last_successful_contribution_time()) =
1433 *last_successful_contribution_time;
1434 }
1435 }
1436 }
1437 if (checkin_result->plan.phase().has_example_query_spec()) {
1438 // Example query plan only supports simple agg for now.
1439 *(federated_selector_context_with_task_name
1440 .mutable_computation_properties()
1441 ->mutable_federated()
1442 ->mutable_simple_aggregation()) = SimpleAggregation();
1443 } else {
1444 const auto& federated_compute_io_router =
1445 checkin_result->plan.phase().federated_compute();
1446 const bool has_simpleagg_tensors =
1447 !federated_compute_io_router.output_filepath_tensor_name().empty();
1448 bool all_aggregations_are_secagg = true;
1449 for (const auto& aggregation : federated_compute_io_router.aggregations()) {
1450 all_aggregations_are_secagg &=
1451 aggregation.second.protocol_config_case() ==
1452 AggregationConfig::kSecureAggregation;
1453 }
1454 if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
1455 federated_selector_context_with_task_name
1456 .mutable_computation_properties()
1457 ->mutable_federated()
1458 ->mutable_secure_aggregation()
1459 ->set_minimum_clients_in_server_visible_aggregate(
1460 checkin_result->minimum_clients_in_server_visible_aggregate);
1461 } else {
1462 // Has an output checkpoint, so some tensors must be simply aggregated.
1463 *(federated_selector_context_with_task_name
1464 .mutable_computation_properties()
1465 ->mutable_federated()
1466 ->mutable_simple_aggregation()) = SimpleAggregation();
1467 }
1468 }
1469 RetryWindow report_retry_window;
1470 phase_logger.LogComputationStarted();
1471 absl::Time run_plan_start_time = absl::Now();
1472 NetworkStats run_plan_start_network_stats =
1473 GetCumulativeNetworkStats(federated_protocol, fedselect_manager);
1474 absl::StatusOr<std::string> checkpoint_output_filename =
1475 files->CreateTempFile("output", ".ckp");
1476 if (!checkpoint_output_filename.ok()) {
1477 auto status = checkpoint_output_filename.status();
1478 auto message = absl::StrCat(
1479 "Could not create temporary output checkpoint file: code: ",
1480 status.code(), ", message: ", status.message());
1481 phase_logger.LogComputationIOError(
1482 status, ExampleStats(),
1483 GetNetworkStatsSince(federated_protocol, fedselect_manager,
1484 run_plan_start_network_stats),
1485 run_plan_start_time);
1486 return fl_runner_result;
1487 }
1488 // Regular plans can use example iterators from the SimpleTaskEnvironment,
1489 // those reading the OpStats DB, or those serving Federated Select slices.
1490 std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1491 CreateSimpleTaskEnvironmentIteratorFactory(
1492 env_deps, federated_selector_context_with_task_name);
1493 std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
1494 fedselect_example_iterator_factory =
1495 fedselect_manager->CreateExampleIteratorFactoryForUriTemplate(
1496 checkin_result->federated_select_uri_template);
1497 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1498 fedselect_example_iterator_factory.get(),
1499 &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1500 PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1501 checkin_result->plan.phase().has_example_query_spec()
1502 ? RunPlanWithExampleQuerySpec(
1503 example_iterator_factories, opstats_logger, flags,
1504 checkin_result->plan, *checkpoint_output_filename)
1505 : RunPlanWithTensorflowSpec(
1506 example_iterator_factories, should_abort, log_manager,
1507 opstats_logger, flags, checkin_result->plan,
1508 checkin_result->checkpoint_input_filename,
1509 *checkpoint_output_filename, timing_config);
1510 // Update the FLRunnerResult fields to account for any network usage during
1511 // the execution of the plan (e.g. due to Federated Select slices having
1512 // been fetched).
1513 UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1514 phase_logger, fl_runner_result);
1515 auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
1516 absl::StatusOr<ComputationResults> computation_results;
1517 if (outcome == engine::PlanOutcome::kSuccess) {
1518 computation_results = CreateComputationResults(
1519 checkin_result->plan.phase().has_example_query_spec()
1520 ? nullptr
1521 : &checkin_result->plan.phase().tensorflow_spec(),
1522 plan_result_and_checkpoint_file);
1523 }
1524 LogComputationOutcome(
1525 plan_result_and_checkpoint_file.plan_result, computation_results.status(),
1526 phase_logger,
1527 GetNetworkStatsSince(federated_protocol, fedselect_manager,
1528 run_plan_start_network_stats),
1529 run_plan_start_time, reference_time);
1530 absl::Status report_result = ReportPlanResult(
1531 federated_protocol, phase_logger, std::move(computation_results),
1532 run_plan_start_time, reference_time);
1533 if (outcome == engine::PlanOutcome::kSuccess && report_result.ok()) {
1534 // Only if training succeeded *and* reporting succeeded do we consider
1535 // the device to have contributed successfully.
1536 fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
1537 }
1538 // Update the FLRunnerResult fields one more time to account for the
1539 // "Report" protocol interaction.
1540 UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
1541 phase_logger, fl_runner_result);
1542 return fl_runner_result;
1543 }
RunPlanWithTensorflowSpecForTesting(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,Files * files,LogManager * log_manager,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & checkpoint_input_filename,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)1544 FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
1545 SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
1546 Files* files, LogManager* log_manager, const Flags* flags,
1547 const ClientOnlyPlan& client_plan,
1548 const std::string& checkpoint_input_filename,
1549 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
1550 const absl::Time run_plan_start_time, const absl::Time reference_time) {
1551 FLRunnerTensorflowSpecResult result;
1552 result.set_outcome(engine::PhaseOutcome::ERROR);
1553 engine::PlanResult plan_result(engine::PlanOutcome::kTensorflowError,
1554 absl::UnknownError(""));
1555 std::function<bool()> should_abort = [env_deps, &timing_config]() {
1556 return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
1557 };
1558 auto opstats_logger =
1559 engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
1560 /*session_name=*/"", /*population_name=*/"");
1561 PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
1562 log_manager, flags);
1563 // Regular plans can use example iterators from the SimpleTaskEnvironment,
1564 // those reading the OpStats DB, or those serving Federated Select slices.
1565 // However, we don't provide a Federated Select-specific example iterator
1566 // factory. That way, the Federated Select slice queries will be forwarded
1567 // to SimpleTaskEnvironment, which can handle them by providing
1568 // test-specific slices if they want to.
1569 //
1570 // Eligibility eval plans can only use iterators from the
1571 // SimpleTaskEnvironment and those reading the OpStats DB.
1572 opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
1573 opstats_logger.get(), log_manager,
1574 flags->opstats_last_successful_contribution_criteria());
1575 std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
1576 CreateSimpleTaskEnvironmentIteratorFactory(env_deps, SelectorContext());
1577 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
1578 &opstats_example_iterator_factory, env_example_iterator_factory.get()};
1579 phase_logger.LogComputationStarted();
1580 if (client_plan.phase().has_federated_compute()) {
1581 absl::StatusOr<std::string> checkpoint_output_filename =
1582 files->CreateTempFile("output", ".ckp");
1583 if (!checkpoint_output_filename.ok()) {
1584 phase_logger.LogComputationIOError(
1585 checkpoint_output_filename.status(), ExampleStats(),
1586 // Empty network stats, since no network protocol is actually used
1587 // in this method.
1588 NetworkStats(), run_plan_start_time);
1589 return result;
1590 }
1591 // Regular TensorflowSpec-based plans.
1592 PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
1593 RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
1594 log_manager, opstats_logger.get(), flags,
1595 client_plan, checkpoint_input_filename,
1596 *checkpoint_output_filename, timing_config);
1597 result.set_checkpoint_output_filename(
1598 plan_result_and_checkpoint_file.checkpoint_file);
1599 plan_result = std::move(plan_result_and_checkpoint_file.plan_result);
1600 } else if (client_plan.phase().has_federated_compute_eligibility()) {
1601 // Eligibility eval plans.
1602 plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
1603 example_iterator_factories, should_abort, log_manager,
1604 opstats_logger.get(), flags, client_plan, checkpoint_input_filename,
1605 timing_config, run_plan_start_time, reference_time);
1606 } else {
1607 // This branch shouldn't be taken, unless we add an additional type of
1608 // TensorflowSpec-based plan in the future. We return a readable error so
1609 // that when such new plan types *are* added, they result in clear
1610 // compatibility test failures when such plans are erroneously targeted at
1611 // old releases that don't support them yet.
1612 event_publisher->PublishIoError("Unsupported TensorflowSpec-based plan");
1613 return result;
1614 }
1615 // Copy output tensors into the result proto.
1616 result.set_outcome(
1617 engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome));
1618 if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
1619 for (int i = 0; i < plan_result.output_names.size(); i++) {
1620 tensorflow::TensorProto output_tensor_proto;
1621 plan_result.output_tensors[i].AsProtoField(&output_tensor_proto);
1622 (*result.mutable_output_tensors())[plan_result.output_names[i]] =
1623 std::move(output_tensor_proto);
1624 }
1625 phase_logger.LogComputationCompleted(
1626 plan_result.example_stats,
1627 // Empty network stats, since no network protocol is actually used in
1628 // this method.
1629 NetworkStats(), run_plan_start_time, reference_time);
1630 } else {
1631 phase_logger.LogComputationTensorflowError(
1632 plan_result.original_status, plan_result.example_stats, NetworkStats(),
1633 run_plan_start_time, reference_time);
1634 }
1635 return result;
1636 }
1637 } // namespace client
1638 } // namespace fcp
1639