/* * Copyright 2020 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fcp/client/fl_runner.h" #include #include #include #include #include #include #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/time/time.h" #include "fcp/base/clock.h" #include "fcp/base/monitoring.h" #include "fcp/base/platform.h" // #include "fcp/client/cache/file_backed_resource_cache.h" #include "fcp/client/cache/resource_cache.h" #include "fcp/client/engine/common.h" #include "fcp/client/engine/engine.pb.h" #include "fcp/client/engine/example_iterator_factory.h" #include "fcp/client/engine/example_query_plan_engine.h" #include "fcp/client/engine/plan_engine_helpers.h" #include "fcp/client/opstats/opstats_utils.h" #include "fcp/client/parsing_utils.h" #ifdef FCP_CLIENT_SUPPORT_TFMOBILE #include "fcp/client/engine/simple_plan_engine.h" #endif #include "fcp/client/engine/tflite_plan_engine.h" #include "fcp/client/event_publisher.h" #include "fcp/client/federated_protocol.h" #include "fcp/client/federated_protocol_util.h" #include "fcp/client/files.h" #include "fcp/client/fl_runner.pb.h" #include "fcp/client/flags.h" #include "fcp/client/http/http_federated_protocol.h" #ifdef FCP_CLIENT_SUPPORT_GRPC #include "fcp/client/grpc_federated_protocol.h" #endif #include "fcp/client/interruptible_runner.h" #include "fcp/client/log_manager.h" #include "fcp/client/opstats/opstats_example_store.h" #include "fcp/client/phase_logger_impl.h" #include "fcp/client/secagg_runner.h" #include "fcp/client/selector_context.pb.h" #include "fcp/client/simple_task_environment.h" #include "fcp/protos/federated_api.pb.h" #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h" #include "fcp/protos/opstats.pb.h" #include "fcp/protos/plan.pb.h" #include "openssl/digest.h" #include "openssl/evp.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" namespace fcp { namespace client { using ::fcp::client::opstats::OpStatsLogger; using ::google::internal::federated::plan::AggregationConfig; using ::google::internal::federated::plan::ClientOnlyPlan; using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter; using ::google::internal::federated::plan::FederatedComputeIORouter; using ::google::internal::federated::plan::TensorflowSpec; using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec; using ::google::internal::federatedml::v2::RetryWindow; using ::google::internal::federatedml::v2::TaskEligibilityInfo; using TfLiteInputs = absl::flat_hash_map; namespace { template void AddValuesToQuantized(QuantizedTensor* quantized, const tensorflow::Tensor& tensor) { auto flat_tensor = tensor.flat(); quantized->values.reserve(quantized->values.size() + flat_tensor.size()); for (int i = 0; i < flat_tensor.size(); i++) { quantized->values.push_back(flat_tensor(i)); } } std::string ComputeSHA256FromStringOrCord( std::variant data) { std::unique_ptr mdctx(EVP_MD_CTX_create(), EVP_MD_CTX_destroy); FCP_CHECK(EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr)); std::string plan_str; if (std::holds_alternative(data)) { plan_str = std::get(data); } else { plan_str = std::string(std::get(data)); } FCP_CHECK(EVP_DigestUpdate(mdctx.get(), plan_str.c_str(), sizeof(int))); const int hash_len = 32; // 32 bytes for SHA-256. uint8_t computation_id_bytes[hash_len]; FCP_CHECK(EVP_DigestFinal_ex(mdctx.get(), computation_id_bytes, nullptr)); return std::string(reinterpret_cast(computation_id_bytes), hash_len); } struct PlanResultAndCheckpointFile { explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result) : plan_result(std::move(plan_result)) {} engine::PlanResult plan_result; std::string checkpoint_file; PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default; PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) = default; // Disallow copy and assign. PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete; PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) = delete; }; // Creates computation results. The method checks for SecAgg tensors only if // `tensorflow_spec != nullptr`. absl::StatusOr CreateComputationResults( const TensorflowSpec* tensorflow_spec, const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file) { const auto& [plan_result, checkpoint_file] = plan_result_and_checkpoint_file; if (plan_result.outcome != engine::PlanOutcome::kSuccess) { return absl::InvalidArgumentError("Computation failed."); } ComputationResults computation_results; if (tensorflow_spec != nullptr) { for (int i = 0; i < plan_result.output_names.size(); i++) { QuantizedTensor quantized; const auto& output_tensor = plan_result.output_tensors[i]; switch (output_tensor.dtype()) { case tensorflow::DT_INT8: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 7; break; case tensorflow::DT_UINT8: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 8; break; case tensorflow::DT_INT16: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 15; break; case tensorflow::DT_UINT16: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 16; break; case tensorflow::DT_INT32: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 31; break; case tensorflow::DT_INT64: AddValuesToQuantized(&quantized, output_tensor); quantized.bitwidth = 62; break; default: return absl::InvalidArgumentError( absl::StrCat("Tensor of type", tensorflow::DataType_Name(output_tensor.dtype()), "could not be converted to quantized value")); } computation_results[plan_result.output_names[i]] = std::move(quantized); } // Add dimensions to QuantizedTensors. for (const tensorflow::TensorSpecProto& tensor_spec : tensorflow_spec->output_tensor_specs()) { if (computation_results.find(tensor_spec.name()) != computation_results.end()) { for (const tensorflow::TensorShapeProto_Dim& dim : tensor_spec.shape().dim()) { std::get(computation_results[tensor_spec.name()]) .dimensions.push_back(dim.size()); } } } } // Name of the TF checkpoint inside the aggregand map in the Checkpoint // protobuf. This field name is ignored by the server. if (!checkpoint_file.empty()) { FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint, fcp::ReadFileToString(checkpoint_file)); computation_results[std::string(kTensorflowCheckpointAggregand)] = std::move(tf_checkpoint); } return computation_results; } #ifdef FCP_CLIENT_SUPPORT_TFMOBILE std::unique_ptr>> ConstructInputsForEligibilityEvalPlan( const FederatedComputeEligibilityIORouter& io_router, const std::string& checkpoint_input_filename) { auto inputs = std::make_unique< std::vector>>(); if (!io_router.input_filepath_tensor_name().empty()) { tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {}); input_filepath.scalar()() = checkpoint_input_filename; inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath}); } return inputs; } #endif std::unique_ptr ConstructTfLiteInputsForEligibilityEvalPlan( const FederatedComputeEligibilityIORouter& io_router, const std::string& checkpoint_input_filename) { auto inputs = std::make_unique(); if (!io_router.input_filepath_tensor_name().empty()) { (*inputs)[io_router.input_filepath_tensor_name()] = checkpoint_input_filename; } return inputs; } // Returns the cumulative network stats (those incurred up until this point in // time). // // The `FederatedSelectManager` object may be null, if it is know that there // has been no network usage from it yet. NetworkStats GetCumulativeNetworkStats( FederatedProtocol* federated_protocol, FederatedSelectManager* fedselect_manager) { NetworkStats result = federated_protocol->GetNetworkStats(); if (fedselect_manager != nullptr) { result = result + fedselect_manager->GetNetworkStats(); } return result; } // Returns the newly incurred network stats since the previous snapshot of // stats (the `reference_point` argument). NetworkStats GetNetworkStatsSince(FederatedProtocol* federated_protocol, FederatedSelectManager* fedselect_manager, const NetworkStats& reference_point) { return GetCumulativeNetworkStats(federated_protocol, fedselect_manager) - reference_point; } // Updates the fields of `FLRunnerResult` that should always be updated after // each interaction with the `FederatedProtocol` or `FederatedSelectManager` // objects. // // The `FederatedSelectManager` object may be null, if it is know that there // has been no network usage from it yet. void UpdateRetryWindowAndNetworkStats(FederatedProtocol& federated_protocol, FederatedSelectManager* fedselect_manager, PhaseLogger& phase_logger, FLRunnerResult& fl_runner_result) { // Update the result's retry window to the most recent one. auto retry_window = federated_protocol.GetLatestRetryWindow(); RetryInfo retry_info; *retry_info.mutable_retry_token() = retry_window.retry_token(); *retry_info.mutable_minimum_delay() = retry_window.delay_min(); *fl_runner_result.mutable_retry_info() = retry_info; phase_logger.UpdateRetryWindowAndNetworkStats( retry_window, GetCumulativeNetworkStats(&federated_protocol, fedselect_manager)); } // Creates an ExampleIteratorFactory that routes queries to the // SimpleTaskEnvironment::CreateExampleIterator() method. std::unique_ptr CreateSimpleTaskEnvironmentIteratorFactory( SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) { return std::make_unique( /*can_handle_func=*/ [](const google::internal::federated::plan::ExampleSelector&) { // The SimpleTaskEnvironment-based ExampleIteratorFactory should // be the catch-all factory that is able to handle all queries // that no other ExampleIteratorFactory is able to handle. return true; }, /*create_iterator_func=*/ [task_env, selector_context]( const google::internal::federated::plan::ExampleSelector& example_selector) { return task_env->CreateExampleIterator(example_selector, selector_context); }, /*should_collect_stats=*/true); } engine::PlanResult RunEligibilityEvalPlanWithTensorflowSpec( std::vector example_iterator_factories, std::function should_abort, LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags, const ClientOnlyPlan& client_plan, const std::string& checkpoint_input_filename, const fcp::client::InterruptibleRunner::TimingConfig& timing_config, const absl::Time run_plan_start_time, const absl::Time reference_time) { // Check that this is a TensorflowSpec-based plan for federated eligibility // computation. if (!client_plan.phase().has_tensorflow_spec() || !client_plan.phase().has_federated_compute_eligibility()) { return engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Invalid eligibility eval plan")); } const FederatedComputeEligibilityIORouter& io_router = client_plan.phase().federated_compute_eligibility(); std::vector output_names = { io_router.task_eligibility_info_tensor_name()}; if (!client_plan.tflite_graph().empty()) { log_manager->LogDiag( ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED); } if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) { std::unique_ptr tflite_inputs = ConstructTfLiteInputsForEligibilityEvalPlan(io_router, checkpoint_input_filename); engine::TfLitePlanEngine plan_engine(example_iterator_factories, should_abort, log_manager, opstats_logger, flags, &timing_config); return plan_engine.RunPlan(client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(), std::move(tflite_inputs), output_names); } #ifdef FCP_CLIENT_SUPPORT_TFMOBILE // Construct input tensors and output tensor names based on the values in // the FederatedComputeEligibilityIORouter message. auto inputs = ConstructInputsForEligibilityEvalPlan( io_router, checkpoint_input_filename); // Run plan and get a set of output tensors back. engine::SimplePlanEngine plan_engine( example_iterator_factories, should_abort, log_manager, opstats_logger, &timing_config, flags->support_constant_tf_inputs()); return plan_engine.RunPlan( client_plan.phase().tensorflow_spec(), client_plan.graph(), client_plan.tensorflow_config_proto(), std::move(inputs), output_names); #else return engine::PlanResult( engine::PlanOutcome::kTensorflowError, absl::InternalError("No eligibility eval plan engine enabled")); #endif } // Validates the output tensors that resulted from executing the plan, and // then parses the output into a TaskEligibilityInfo proto. Returns an error // if validation or parsing failed. absl::StatusOr ParseEligibilityEvalPlanOutput( const std::vector& output_tensors) { auto output_size = output_tensors.size(); if (output_size != 1) { return absl::InvalidArgumentError( absl::StrCat("Unexpected number of output tensors: ", output_size)); } auto output_elements = output_tensors[0].NumElements(); if (output_elements != 1) { return absl::InvalidArgumentError(absl::StrCat( "Unexpected number of output tensor elements: ", output_elements)); } tensorflow::DataType output_type = output_tensors[0].dtype(); if (output_type != tensorflow::DT_STRING) { return absl::InvalidArgumentError( absl::StrCat("Unexpected output tensor type: ", output_type)); } // Extract the serialized TaskEligibilityInfo proto from the tensor and // parse it. // First, convert the output Tensor into a Scalar (= a TensorMap with 1 // element), then use its operator() to access the actual data. const tensorflow::tstring& serialized_output = output_tensors[0].scalar()(); TaskEligibilityInfo parsed_output; if (!parsed_output.ParseFromString(serialized_output)) { return absl::InvalidArgumentError("Could not parse output proto"); } return parsed_output; } #ifdef FCP_CLIENT_SUPPORT_TFMOBILE std::unique_ptr>> ConstructInputsForTensorflowSpecPlan( const FederatedComputeIORouter& io_router, const std::string& checkpoint_input_filename, const std::string& checkpoint_output_filename) { auto inputs = std::make_unique< std::vector>>(); if (!io_router.input_filepath_tensor_name().empty()) { tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {}); input_filepath.scalar()() = checkpoint_input_filename; inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath}); } if (!io_router.output_filepath_tensor_name().empty()) { tensorflow::Tensor output_filepath(tensorflow::DT_STRING, {}); output_filepath.scalar()() = checkpoint_output_filename; inputs->push_back( {io_router.output_filepath_tensor_name(), output_filepath}); } return inputs; } #endif std::unique_ptr ConstructTFLiteInputsForTensorflowSpecPlan( const FederatedComputeIORouter& io_router, const std::string& checkpoint_input_filename, const std::string& checkpoint_output_filename) { auto inputs = std::make_unique(); if (!io_router.input_filepath_tensor_name().empty()) { (*inputs)[io_router.input_filepath_tensor_name()] = checkpoint_input_filename; } if (!io_router.output_filepath_tensor_name().empty()) { (*inputs)[io_router.output_filepath_tensor_name()] = checkpoint_output_filename; } return inputs; } absl::StatusOr> ConstructOutputsWithDeterministicOrder( const TensorflowSpec& tensorflow_spec, const FederatedComputeIORouter& io_router) { std::vector output_names; // The order of output tensor names should match the order in // TensorflowSpec. for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) { std::string tensor_name = output_tensor_spec.name(); if (!io_router.aggregations().contains(tensor_name) || !io_router.aggregations().at(tensor_name).has_secure_aggregation()) { return absl::InvalidArgumentError( "Output tensor is missing in AggregationConfig, or has unsupported " "aggregation type."); } output_names.push_back(tensor_name); } return output_names; } PlanResultAndCheckpointFile RunPlanWithTensorflowSpec( std::vector example_iterator_factories, std::function should_abort, LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags, const ClientOnlyPlan& client_plan, const std::string& checkpoint_input_filename, const std::string& checkpoint_output_filename, const fcp::client::InterruptibleRunner::TimingConfig& timing_config) { if (!client_plan.phase().has_tensorflow_spec()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Plan must include TensorflowSpec."))); } if (!client_plan.phase().has_federated_compute()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Invalid TensorflowSpec-based plan"))); } // Get the output tensor names. absl::StatusOr> output_names; output_names = ConstructOutputsWithDeterministicOrder( client_plan.phase().tensorflow_spec(), client_plan.phase().federated_compute()); if (!output_names.ok()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, output_names.status())); } // Run plan and get a set of output tensors back. if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) { std::unique_ptr tflite_inputs = ConstructTFLiteInputsForTensorflowSpecPlan( client_plan.phase().federated_compute(), checkpoint_input_filename, checkpoint_output_filename); engine::TfLitePlanEngine plan_engine(example_iterator_factories, should_abort, log_manager, opstats_logger, flags, &timing_config); engine::PlanResult plan_result = plan_engine.RunPlan( client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(), std::move(tflite_inputs), *output_names); PlanResultAndCheckpointFile result(std::move(plan_result)); result.checkpoint_file = checkpoint_output_filename; return result; } #ifdef FCP_CLIENT_SUPPORT_TFMOBILE // Construct input tensors based on the values in the // FederatedComputeIORouter message and create a temporary file for the // output checkpoint if needed. auto inputs = ConstructInputsForTensorflowSpecPlan( client_plan.phase().federated_compute(), checkpoint_input_filename, checkpoint_output_filename); engine::SimplePlanEngine plan_engine( example_iterator_factories, should_abort, log_manager, opstats_logger, &timing_config, flags->support_constant_tf_inputs()); engine::PlanResult plan_result = plan_engine.RunPlan( client_plan.phase().tensorflow_spec(), client_plan.graph(), client_plan.tensorflow_config_proto(), std::move(inputs), *output_names); PlanResultAndCheckpointFile result(std::move(plan_result)); result.checkpoint_file = checkpoint_output_filename; return result; #else return PlanResultAndCheckpointFile( engine::PlanResult(engine::PlanOutcome::kTensorflowError, absl::InternalError("No plan engine enabled"))); #endif } PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec( std::vector example_iterator_factories, OpStatsLogger* opstats_logger, const Flags* flags, const ClientOnlyPlan& client_plan, const std::string& checkpoint_output_filename) { if (!client_plan.phase().has_example_query_spec()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Plan must include ExampleQuerySpec"))); } if (!flags->enable_example_query_plan_engine()) { // Example query plan received while the flag is off. return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError( "Example query plan received while the flag is off"))); } if (!client_plan.phase().has_federated_example_query()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan"))); } for (const auto& example_query : client_plan.phase().example_query_spec().example_queries()) { for (auto const& [vector_name, spec] : example_query.output_vector_specs()) { const auto& aggregations = client_plan.phase().federated_example_query().aggregations(); if ((aggregations.find(vector_name) == aggregations.end()) || !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) { return PlanResultAndCheckpointFile(engine::PlanResult( engine::PlanOutcome::kInvalidArgument, absl::InvalidArgumentError("Output vector is missing in " "AggregationConfig, or has unsupported " "aggregation type."))); } } } engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories, opstats_logger); engine::PlanResult plan_result = plan_engine.RunPlan( client_plan.phase().example_query_spec(), checkpoint_output_filename); PlanResultAndCheckpointFile result(std::move(plan_result)); result.checkpoint_file = checkpoint_output_filename; return result; } void LogEligibilityEvalComputationOutcome( PhaseLogger& phase_logger, engine::PlanResult plan_result, const absl::Status& eligibility_info_parsing_status, absl::Time run_plan_start_time, absl::Time reference_time) { switch (plan_result.outcome) { case engine::PlanOutcome::kSuccess: { if (eligibility_info_parsing_status.ok()) { phase_logger.LogEligibilityEvalComputationCompleted( plan_result.example_stats, run_plan_start_time, reference_time); } else { phase_logger.LogEligibilityEvalComputationTensorflowError( eligibility_info_parsing_status, plan_result.example_stats, run_plan_start_time, reference_time); FCP_LOG(ERROR) << eligibility_info_parsing_status.message(); } break; } case engine::PlanOutcome::kInterrupted: phase_logger.LogEligibilityEvalComputationInterrupted( plan_result.original_status, plan_result.example_stats, run_plan_start_time, reference_time); break; case engine::PlanOutcome::kInvalidArgument: phase_logger.LogEligibilityEvalComputationInvalidArgument( plan_result.original_status, plan_result.example_stats, run_plan_start_time); break; case engine::PlanOutcome::kTensorflowError: phase_logger.LogEligibilityEvalComputationTensorflowError( plan_result.original_status, plan_result.example_stats, run_plan_start_time, reference_time); break; case engine::PlanOutcome::kExampleIteratorError: phase_logger.LogEligibilityEvalComputationExampleIteratorError( plan_result.original_status, plan_result.example_stats, run_plan_start_time); break; } } void LogComputationOutcome(const engine::PlanResult& plan_result, absl::Status computation_results_parsing_status, PhaseLogger& phase_logger, const NetworkStats& network_stats, absl::Time run_plan_start_time, absl::Time reference_time) { switch (plan_result.outcome) { case engine::PlanOutcome::kSuccess: { if (computation_results_parsing_status.ok()) { phase_logger.LogComputationCompleted(plan_result.example_stats, network_stats, run_plan_start_time, reference_time); } else { phase_logger.LogComputationTensorflowError( computation_results_parsing_status, plan_result.example_stats, network_stats, run_plan_start_time, reference_time); } break; } case engine::PlanOutcome::kInterrupted: phase_logger.LogComputationInterrupted( plan_result.original_status, plan_result.example_stats, network_stats, run_plan_start_time, reference_time); break; case engine::PlanOutcome::kInvalidArgument: phase_logger.LogComputationInvalidArgument( plan_result.original_status, plan_result.example_stats, network_stats, run_plan_start_time); break; case engine::PlanOutcome::kTensorflowError: phase_logger.LogComputationTensorflowError( plan_result.original_status, plan_result.example_stats, network_stats, run_plan_start_time, reference_time); break; case engine::PlanOutcome::kExampleIteratorError: phase_logger.LogComputationExampleIteratorError( plan_result.original_status, plan_result.example_stats, network_stats, run_plan_start_time); break; } } void LogResultUploadStatus(PhaseLogger& phase_logger, absl::Status result, const NetworkStats& network_stats, absl::Time time_before_result_upload, absl::Time reference_time) { if (result.ok()) { phase_logger.LogResultUploadCompleted( network_stats, time_before_result_upload, reference_time); } else { auto message = absl::StrCat("Error reporting results: code: ", result.code(), ", message: ", result.message()); FCP_LOG(INFO) << message; if (result.code() == absl::StatusCode::kAborted) { phase_logger.LogResultUploadServerAborted( result, network_stats, time_before_result_upload, reference_time); } else if (result.code() == absl::StatusCode::kCancelled) { phase_logger.LogResultUploadClientInterrupted( result, network_stats, time_before_result_upload, reference_time); } else { phase_logger.LogResultUploadIOError( result, network_stats, time_before_result_upload, reference_time); } } } void LogFailureUploadStatus(PhaseLogger& phase_logger, absl::Status result, const NetworkStats& network_stats, absl::Time time_before_failure_upload, absl::Time reference_time) { if (result.ok()) { phase_logger.LogFailureUploadCompleted( network_stats, time_before_failure_upload, reference_time); } else { auto message = absl::StrCat("Error reporting computation failure: code: ", result.code(), ", message: ", result.message()); FCP_LOG(INFO) << message; if (result.code() == absl::StatusCode::kAborted) { phase_logger.LogFailureUploadServerAborted( result, network_stats, time_before_failure_upload, reference_time); } else if (result.code() == absl::StatusCode::kCancelled) { phase_logger.LogFailureUploadClientInterrupted( result, network_stats, time_before_failure_upload, reference_time); } else { phase_logger.LogFailureUploadIOError( result, network_stats, time_before_failure_upload, reference_time); } } } absl::Status ReportPlanResult( FederatedProtocol* federated_protocol, PhaseLogger& phase_logger, absl::StatusOr computation_results, absl::Time run_plan_start_time, absl::Time reference_time) { const absl::Time before_report_time = absl::Now(); // Note that the FederatedSelectManager shouldn't be active anymore during // the reporting of results, so we don't bother passing it to // GetNetworkStatsSince. // // We must return only stats that cover the report phase for the log events // below. const NetworkStats before_report_stats = GetCumulativeNetworkStats(federated_protocol, /*fedselect_manager=*/nullptr); absl::Status result = absl::InternalError(""); if (computation_results.ok()) { FCP_RETURN_IF_ERROR(phase_logger.LogResultUploadStarted()); result = federated_protocol->ReportCompleted( std::move(*computation_results), /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt); LogResultUploadStatus(phase_logger, result, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, before_report_stats), before_report_time, reference_time); } else { FCP_RETURN_IF_ERROR(phase_logger.LogFailureUploadStarted()); result = federated_protocol->ReportNotCompleted( engine::PhaseOutcome::ERROR, /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt); LogFailureUploadStatus(phase_logger, result, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, before_report_stats), before_report_time, reference_time); } return result; } // Writes the given data to the stream, and returns true if successful and // false if not. bool WriteStringOrCordToFstream( std::fstream& stream, const std::variant& data) { if (stream.fail()) { return false; } if (std::holds_alternative(data)) { return (stream << std::get(data)).good(); } for (absl::string_view chunk : std::get(data).Chunks()) { if (!(stream << chunk).good()) { return false; } } return true; } // Writes the given checkpoint data to a newly created temporary file. // Returns the filename if successful, or an error if the file could not be // created, or if writing to the file failed. absl::StatusOr CreateInputCheckpointFile( Files* files, const std::variant& checkpoint) { // Create the temporary checkpoint file. // Deletion of the file is left to the caller / the Files implementation. FCP_ASSIGN_OR_RETURN(absl::StatusOr filename, files->CreateTempFile("init", ".ckp")); // Write the checkpoint data to the file. std::fstream checkpoint_stream(*filename, std::ios_base::out); if (!WriteStringOrCordToFstream(checkpoint_stream, checkpoint)) { return absl::InvalidArgumentError("Failed to write to file"); } checkpoint_stream.close(); return filename; } absl::StatusOr> RunEligibilityEvalPlan( const FederatedProtocol::EligibilityEvalTask& eligibility_eval_task, std::vector example_iterator_factories, std::function should_abort, PhaseLogger& phase_logger, Files* files, LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags, FederatedProtocol* federated_protocol, const fcp::client::InterruptibleRunner::TimingConfig& timing_config, const absl::Time reference_time, const absl::Time time_before_checkin, const absl::Time time_before_plan_download, const NetworkStats& network_stats) { ClientOnlyPlan plan; if (!ParseFromStringOrCord(plan, eligibility_eval_task.payloads.plan)) { auto message = "Failed to parse received eligibility eval plan"; phase_logger.LogEligibilityEvalCheckinInvalidPayloadError( message, network_stats, time_before_plan_download); FCP_LOG(ERROR) << message; return absl::InternalError(message); } absl::StatusOr checkpoint_input_filename = CreateInputCheckpointFile(files, eligibility_eval_task.payloads.checkpoint); if (!checkpoint_input_filename.ok()) { auto status = checkpoint_input_filename.status(); auto message = absl::StrCat( "Failed to create eligibility eval checkpoint input file: code: ", status.code(), ", message: ", status.message()); phase_logger.LogEligibilityEvalCheckinIOError(status, network_stats, time_before_plan_download); FCP_LOG(ERROR) << message; return absl::InternalError(""); } phase_logger.LogEligibilityEvalCheckinCompleted(network_stats, /*time_before_checkin=*/ time_before_checkin, /*time_before_plan_download=*/ time_before_plan_download); absl::Time run_plan_start_time = absl::Now(); phase_logger.LogEligibilityEvalComputationStarted(); engine::PlanResult plan_result = RunEligibilityEvalPlanWithTensorflowSpec( example_iterator_factories, should_abort, log_manager, opstats_logger, flags, plan, *checkpoint_input_filename, timing_config, run_plan_start_time, reference_time); absl::StatusOr task_eligibility_info; if (plan_result.outcome == engine::PlanOutcome::kSuccess) { task_eligibility_info = ParseEligibilityEvalPlanOutput(plan_result.output_tensors); } LogEligibilityEvalComputationOutcome(phase_logger, std::move(plan_result), task_eligibility_info.status(), run_plan_start_time, reference_time); return task_eligibility_info; } struct EligibilityEvalResult { std::optional task_eligibility_info; std::vector task_names_for_multiple_task_assignments; }; // Create an EligibilityEvalResult from a TaskEligibilityInfo and // PopulationEligibilitySpec. If both population_spec and // task_eligibility_info are present, the returned EligibilityEvalResult will // contain a TaskEligibilityInfo which only contains the tasks for single task // assignment, and a vector of task names for multiple task assignment. EligibilityEvalResult CreateEligibilityEvalResult( const std::optional& task_eligibility_info, const std::optional& population_spec) { EligibilityEvalResult result; if (population_spec.has_value() && task_eligibility_info.has_value()) { absl::flat_hash_set task_names_for_multiple_task_assignments; for (const auto& task_info : population_spec.value().task_info()) { if (task_info.task_assignment_mode() == PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE) { task_names_for_multiple_task_assignments.insert(task_info.task_name()); } } TaskEligibilityInfo single_task_assignment_eligibility_info; single_task_assignment_eligibility_info.set_version( task_eligibility_info.value().version()); for (const auto& task_weight : task_eligibility_info.value().task_weights()) { if (task_names_for_multiple_task_assignments.contains( task_weight.task_name())) { result.task_names_for_multiple_task_assignments.push_back( task_weight.task_name()); } else { *single_task_assignment_eligibility_info.mutable_task_weights()->Add() = task_weight; } } result.task_eligibility_info = single_task_assignment_eligibility_info; } else { result.task_eligibility_info = task_eligibility_info; } return result; } // Issues an eligibility eval checkin request and executes the eligibility // eval task if the server returns one. // // This function modifies the FLRunnerResult with values received over the // course of the eligibility eval protocol interaction. // // Returns: // - the TaskEligibilityInfo produced by the eligibility eval task, if the // server provided an eligibility eval task to run. // - an std::nullopt if the server indicated that there is no eligibility eval // task configured for the population. // - an INTERNAL error if the server rejects the client or another error // occurs // that should abort the training run. The error will already have been // logged appropriately. absl::StatusOr IssueEligibilityEvalCheckinAndRunPlan( std::vector example_iterator_factories, std::function should_abort, PhaseLogger& phase_logger, Files* files, LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags, FederatedProtocol* federated_protocol, const fcp::client::InterruptibleRunner::TimingConfig& timing_config, const absl::Time reference_time, FLRunnerResult& fl_runner_result) { const absl::Time time_before_checkin = absl::Now(); const NetworkStats network_stats_before_checkin = GetCumulativeNetworkStats(federated_protocol, /*fedselect_manager=*/nullptr); // These fields will, after a successful checkin that resulted in an EET // being received, contain the time at which the EET plan/checkpoint URIs // were received (but not yet downloaded), as well as the cumulative network // stats at that point, allowing us to separately calculate how long it took // to then download the actual payloads. absl::Time time_before_plan_download = time_before_checkin; NetworkStats network_stats_before_plan_download = network_stats_before_checkin; // Log that we are about to check in with the server. phase_logger.LogEligibilityEvalCheckinStarted(); // Issue the eligibility eval checkin request (providing a callback that // will be called when an EET is assigned to the task but before its // plan/checkpoint URIs have actually been downloaded). bool plan_uris_received_callback_called = false; std::function plan_uris_received_callback = [&time_before_plan_download, &network_stats_before_plan_download, &time_before_checkin, &network_stats_before_checkin, &federated_protocol, &phase_logger, &plan_uris_received_callback_called]( const FederatedProtocol::EligibilityEvalTask& task) { // When the plan URIs have been received, we already know the name // of the task we have been assigned, so let's tell the // PhaseLogger. phase_logger.SetModelIdentifier(task.execution_id); // We also should log a corresponding log event. phase_logger.LogEligibilityEvalCheckinPlanUriReceived( GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_checkin), time_before_checkin); // And we must take a snapshot of the current time & network // stats, so we can distinguish between the duration/network stats // incurred for the checkin request vs. the actual downloading of // the plan/checkpoint resources. time_before_plan_download = absl::Now(); network_stats_before_plan_download = GetCumulativeNetworkStats(federated_protocol, /*fedselect_manager=*/nullptr); plan_uris_received_callback_called = true; }; absl::StatusOr eligibility_checkin_result = federated_protocol->EligibilityEvalCheckin( plan_uris_received_callback); UpdateRetryWindowAndNetworkStats(*federated_protocol, /*fedselect_manager=*/nullptr, phase_logger, fl_runner_result); // It's a bit unfortunate that we have to inspect the checkin_result and // extract the model identifier here rather than further down the function, // but this ensures that the histograms below will have the right model // identifier attached (and we want to also emit the histograms even if we // have failed/rejected checkin outcomes). if (eligibility_checkin_result.ok() && std::holds_alternative( *eligibility_checkin_result)) { // Make sure that if we received an EligibilityEvalTask, then the callback // should have already been called by this point by the protocol (ensuring // that SetModelIdentifier has been called etc.). FCP_CHECK(plan_uris_received_callback_called); } if (!eligibility_checkin_result.ok()) { auto status = eligibility_checkin_result.status(); auto message = absl::StrCat("Error during eligibility eval checkin: code: ", status.code(), ", message: ", status.message()); if (status.code() == absl::StatusCode::kAborted) { phase_logger.LogEligibilityEvalCheckinServerAborted( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download); } else if (status.code() == absl::StatusCode::kCancelled) { phase_logger.LogEligibilityEvalCheckinClientInterrupted( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download); } else if (!status.ok()) { phase_logger.LogEligibilityEvalCheckinIOError( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download); } FCP_LOG(INFO) << message; return absl::InternalError(""); } EligibilityEvalResult result; if (std::holds_alternative( *eligibility_checkin_result)) { phase_logger.LogEligibilityEvalCheckinTurnedAway( GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_checkin), time_before_checkin); // If the server explicitly rejected our request, then we must abort and // we must not proceed to the "checkin" phase below. FCP_LOG(INFO) << "Device rejected by server during eligibility eval " "checkin; aborting"; return absl::InternalError(""); } else if (std::holds_alternative( *eligibility_checkin_result)) { phase_logger.LogEligibilityEvalNotConfigured( GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_checkin), time_before_checkin); // If the server indicates that no eligibility eval task is configured for // the population then there is nothing more to do. We simply proceed to // the "checkin" phase below without providing it a TaskEligibilityInfo // proto. result.task_eligibility_info = std::nullopt; return result; } auto eligibility_eval_task = absl::get( *eligibility_checkin_result); // Parse and run the eligibility eval task if the server returned one. // Now we have a EligibilityEvalTask, if an error happens, we will report to // the server via the ReportEligibilityEvalError. absl::StatusOr> task_eligibility_info = RunEligibilityEvalPlan( eligibility_eval_task, example_iterator_factories, should_abort, phase_logger, files, log_manager, opstats_logger, flags, federated_protocol, timing_config, reference_time, /*time_before_checkin=*/time_before_checkin, /*time_before_plan_download=*/time_before_plan_download, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download)); if (!task_eligibility_info.ok()) { // Note that none of the PhaseLogger methods will reflect the very little // amount of network usage the will be incurred by this protocol request. // We consider this to be OK to keep things simple, and because this // should use such a limited amount of network bandwidth. Do note that the // network usage *will* be correctly reported in the OpStats database. federated_protocol->ReportEligibilityEvalError( absl::Status(task_eligibility_info.status().code(), "Failed to compute eligibility info")); UpdateRetryWindowAndNetworkStats(*federated_protocol, /*fedselect_manager=*/nullptr, phase_logger, fl_runner_result); return task_eligibility_info.status(); } return CreateEligibilityEvalResult( *task_eligibility_info, eligibility_eval_task.population_eligibility_spec); } struct CheckinResult { std::string task_name; ClientOnlyPlan plan; int32_t minimum_clients_in_server_visible_aggregate; std::string checkpoint_input_filename; std::string computation_id; std::string federated_select_uri_template; }; absl::StatusOr IssueCheckin( PhaseLogger& phase_logger, LogManager* log_manager, Files* files, FederatedProtocol* federated_protocol, std::optional task_eligibility_info, absl::Time reference_time, const std::string& population_name, FLRunnerResult& fl_runner_result, const Flags* flags) { absl::Time time_before_checkin = absl::Now(); // We must return only stats that cover the check in phase for the log // events below. const NetworkStats network_stats_before_checkin = GetCumulativeNetworkStats(federated_protocol, /*fedselect_manager=*/nullptr); // These fields will, after a successful checkin that resulted in a task // being assigned, contain the time at which the task plan/checkpoint URIs // were received (but not yet downloaded), as well as the cumulative network // stats at that point, allowing us to separately calculate how long it took // to then download the actual payloads. absl::Time time_before_plan_download = time_before_checkin; NetworkStats network_stats_before_plan_download = network_stats_before_checkin; // Clear the model identifier before check-in, to ensure that the any prior // eligibility eval task name isn't used any longer. phase_logger.SetModelIdentifier(""); phase_logger.LogCheckinStarted(); std::string task_name; // Issue the checkin request (providing a callback that will be called when // an EET is assigned to the task but before its plan/checkpoint URIs have // actually been downloaded). bool plan_uris_received_callback_called = false; std::function plan_uris_received_callback = [&time_before_plan_download, &network_stats_before_plan_download, &time_before_checkin, &network_stats_before_checkin, &task_name, &federated_protocol, &population_name, &log_manager, &phase_logger, &plan_uris_received_callback_called]( const FederatedProtocol::TaskAssignment& task_assignment) { // When the plan URIs have been received, we already know the name // of the task we have been assigned, so let's tell the // PhaseLogger. auto model_identifier = task_assignment.aggregation_session_id; phase_logger.SetModelIdentifier(model_identifier); // We also should log a corresponding log event. task_name = ExtractTaskNameFromAggregationSessionId( model_identifier, population_name, *log_manager); phase_logger.LogCheckinPlanUriReceived( task_name, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_checkin), time_before_checkin); // And we must take a snapshot of the current time & network // stats, so we can distinguish between the duration/network stats // incurred for the checkin request vs. the actual downloading of // the plan/checkpoint resources. time_before_plan_download = absl::Now(); network_stats_before_plan_download = GetCumulativeNetworkStats( federated_protocol, /*fedselect_manager=*/nullptr); plan_uris_received_callback_called = true; }; absl::StatusOr checkin_result = federated_protocol->Checkin(task_eligibility_info, plan_uris_received_callback); UpdateRetryWindowAndNetworkStats(*federated_protocol, /*fedselect_manager=*/nullptr, phase_logger, fl_runner_result); // It's a bit unfortunate that we have to inspect the checkin_result and // extract the model identifier here rather than further down the function, // but this ensures that the histograms below will have the right model // identifier attached (and we want to also emit the histograms even if we // have failed/rejected checkin outcomes). if (checkin_result.ok() && std::holds_alternative( *checkin_result)) { // Make sure that if we received a TaskAssignment, then the callback // should have already been called by this point by the protocol (ensuring // that SetModelIdentifier has been called etc.). FCP_CHECK(plan_uris_received_callback_called); } if (!checkin_result.ok()) { auto status = checkin_result.status(); auto message = absl::StrCat("Error during checkin: code: ", status.code(), ", message: ", status.message()); if (status.code() == absl::StatusCode::kAborted) { phase_logger.LogCheckinServerAborted( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download, reference_time); } else if (status.code() == absl::StatusCode::kCancelled) { phase_logger.LogCheckinClientInterrupted( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download, reference_time); } else if (!status.ok()) { phase_logger.LogCheckinIOError( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download, reference_time); } FCP_LOG(INFO) << message; return status; } // Server rejected us? Return the fl_runner_results as-is. if (std::holds_alternative(*checkin_result)) { phase_logger.LogCheckinTurnedAway( GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_checkin), time_before_checkin, reference_time); FCP_LOG(INFO) << "Device rejected by server during checkin; aborting"; return absl::InternalError("Device rejected by server."); } auto task_assignment = absl::get(*checkin_result); ClientOnlyPlan plan; auto plan_bytes = task_assignment.payloads.plan; if (!ParseFromStringOrCord(plan, plan_bytes)) { auto message = "Failed to parse received plan"; phase_logger.LogCheckinInvalidPayload( message, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download, reference_time); FCP_LOG(ERROR) << message; return absl::InternalError(""); } std::string computation_id; if (flags->enable_computation_id()) { computation_id = ComputeSHA256FromStringOrCord(plan_bytes); } int32_t minimum_clients_in_server_visible_aggregate = 0; if (task_assignment.sec_agg_info.has_value()) { auto minimum_number_of_participants = plan.phase().minimum_number_of_participants(); if (task_assignment.sec_agg_info->expected_number_of_clients < minimum_number_of_participants) { return absl::InternalError( "expectedNumberOfClients was less than Plan's " "minimumNumberOfParticipants."); } minimum_clients_in_server_visible_aggregate = task_assignment.sec_agg_info ->minimum_clients_in_server_visible_aggregate; } absl::StatusOr checkpoint_input_filename = ""; // Example query plan does not have an input checkpoint. if (!plan.phase().has_example_query_spec()) { checkpoint_input_filename = CreateInputCheckpointFile(files, task_assignment.payloads.checkpoint); if (!checkpoint_input_filename.ok()) { auto status = checkpoint_input_filename.status(); auto message = absl::StrCat( "Failed to create checkpoint input file: code: ", status.code(), ", message: ", status.message()); phase_logger.LogCheckinIOError( status, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), time_before_plan_download, reference_time); FCP_LOG(ERROR) << message; return status; } } phase_logger.LogCheckinCompleted( task_name, GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr, network_stats_before_plan_download), /*time_before_checkin=*/time_before_checkin, /*time_before_plan_download=*/time_before_plan_download, reference_time); return CheckinResult{ .task_name = std::move(task_name), .plan = std::move(plan), .minimum_clients_in_server_visible_aggregate = minimum_clients_in_server_visible_aggregate, .checkpoint_input_filename = std::move(*checkpoint_input_filename), .computation_id = std::move(computation_id), .federated_select_uri_template = task_assignment.federated_select_uri_template}; } } // namespace absl::StatusOr RunFederatedComputation( SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher, Files* files, LogManager* log_manager, const Flags* flags, const std::string& federated_service_uri, const std::string& api_key, const std::string& test_cert_path, const std::string& session_name, const std::string& population_name, const std::string& retry_token, const std::string& client_version, const std::string& attestation_measurement) { auto opstats_logger = engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager, session_name, population_name); absl::Time reference_time = absl::Now(); FLRunnerResult fl_runner_result; fcp::client::InterruptibleRunner::TimingConfig timing_config = { .polling_period = absl::Milliseconds(flags->condition_polling_period_millis()), .graceful_shutdown_period = absl::Milliseconds( flags->tf_execution_teardown_grace_period_millis()), .extended_shutdown_period = absl::Milliseconds( flags->tf_execution_teardown_extended_period_millis()), }; auto should_abort_protocol_callback = [&env_deps, &timing_config]() -> bool { // Return the Status if failed, or the negated value if successful. return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period); }; PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(), log_manager, flags); // If there was an error initializing OpStats, opstats_logger will be a // no-op implementation and execution will be allowed to continue. if (!opstats_logger->GetInitStatus().ok()) { // This will only happen if OpStats is enabled and there was an error in // initialization. phase_logger.LogNonfatalInitializationError( opstats_logger->GetInitStatus()); } Clock* clock = Clock::RealClock(); std::unique_ptr resource_cache; // if (flags->max_resource_cache_size_bytes() > 0) { // // Anything that goes wrong in FileBackedResourceCache::Create is a // // programmer error. // absl::StatusOr> // resource_cache_internal = cache::FileBackedResourceCache::Create( // env_deps->GetBaseDir(), env_deps->GetCacheDir(), log_manager, // clock, flags->max_resource_cache_size_bytes()); // if (!resource_cache_internal.ok()) { // auto resource_init_failed_status = absl::Status( // resource_cache_internal.status().code(), // absl::StrCat("Failed to initialize FileBackedResourceCache: ", // resource_cache_internal.status().ToString())); // if (flags->resource_cache_initialization_error_is_fatal()) { // phase_logger.LogFatalInitializationError(resource_init_failed_status); // return resource_init_failed_status; // } // // We log an error but otherwise proceed as if the cache was disabled. // phase_logger.LogNonfatalInitializationError(resource_init_failed_status); // } else { // resource_cache = std::move(*resource_cache_internal); // } // } std::unique_ptr<::fcp::client::http::HttpClient> http_client = flags->enable_grpc_with_http_resource_support() || flags->use_http_federated_compute_protocol() ? env_deps->CreateHttpClient() : nullptr; std::unique_ptr federated_protocol; if (flags->use_http_federated_compute_protocol()) { log_manager->LogDiag(ProdDiagCode::HTTP_FEDERATED_PROTOCOL_USED); // Verify the entry point uri starts with "https://" or // "http://localhost". Note "http://localhost" is allowed for testing // purpose. if (!(absl::StartsWith(federated_service_uri, "https://") || absl::StartsWith(federated_service_uri, "http://localhost"))) { return absl::InvalidArgumentError("The entry point uri is invalid."); } federated_protocol = std::make_unique( clock, log_manager, flags, http_client.get(), std::make_unique(), event_publisher->secagg_event_publisher(), federated_service_uri, api_key, population_name, retry_token, client_version, attestation_measurement, should_abort_protocol_callback, absl::BitGen(), timing_config, resource_cache.get()); } else { #ifdef FCP_CLIENT_SUPPORT_GRPC // Check in with the server to either retrieve a plan + initial // checkpoint, or get rejected with a RetryWindow. auto grpc_channel_deadline = flags->grpc_channel_deadline_seconds(); if (grpc_channel_deadline <= 0) { grpc_channel_deadline = 600; FCP_LOG(INFO) << "Using default channel deadline of " << grpc_channel_deadline << " seconds."; } federated_protocol = std::make_unique( event_publisher, log_manager, std::make_unique(), flags, http_client.get(), federated_service_uri, api_key, test_cert_path, population_name, retry_token, client_version, attestation_measurement, should_abort_protocol_callback, timing_config, grpc_channel_deadline, resource_cache.get()); #else return absl::InternalError("No FederatedProtocol enabled."); #endif } std::unique_ptr federated_select_manager; if (http_client != nullptr && flags->enable_federated_select()) { federated_select_manager = std::make_unique( log_manager, files, http_client.get(), should_abort_protocol_callback, timing_config); } else { federated_select_manager = std::make_unique(log_manager); } return RunFederatedComputation(env_deps, phase_logger, event_publisher, files, log_manager, opstats_logger.get(), flags, federated_protocol.get(), federated_select_manager.get(), timing_config, reference_time, session_name, population_name); } absl::StatusOr RunFederatedComputation( SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger, EventPublisher* event_publisher, Files* files, LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags, FederatedProtocol* federated_protocol, FederatedSelectManager* fedselect_manager, const fcp::client::InterruptibleRunner::TimingConfig& timing_config, const absl::Time reference_time, const std::string& session_name, const std::string& population_name) { SelectorContext federated_selector_context; federated_selector_context.mutable_computation_properties()->set_session_name( session_name); FederatedComputation federated_computation; federated_computation.set_population_name(population_name); *federated_selector_context.mutable_computation_properties() ->mutable_federated() = federated_computation; SelectorContext eligibility_selector_context; eligibility_selector_context.mutable_computation_properties() ->set_session_name(session_name); EligibilityEvalComputation eligibility_eval_computation; eligibility_eval_computation.set_population_name(population_name); *eligibility_selector_context.mutable_computation_properties() ->mutable_eligibility_eval() = eligibility_eval_computation; // Construct a default FLRunnerResult that reflects an unsuccessful training // attempt and which uses RetryWindow corresponding to transient errors (if // the flag is on). // This is what will be returned if we have to bail early, before we've // received a RetryWindow from the server. FLRunnerResult fl_runner_result; fl_runner_result.set_contribution_result(FLRunnerResult::FAIL); // Before we even check whether we should abort right away, update the retry // window. That way we will use the most appropriate retry window we have // available (an implementation detail of FederatedProtocol, but generally a // 'transient error' retry window based on the provided flag values) in case // we do need to abort. UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager, phase_logger, fl_runner_result); // Check if the device conditions allow for checking in with the server // and running a federated computation. If not, bail early with the // transient error retry window. std::function should_abort = [env_deps, &timing_config]() { return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period); }; if (should_abort()) { std::string message = "Device conditions not satisfied, aborting federated computation"; FCP_LOG(INFO) << message; phase_logger.LogTaskNotStarted(message); return fl_runner_result; } // Eligibility eval plans can use example iterators from the // SimpleTaskEnvironment and those reading the OpStats DB. opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory( opstats_logger, log_manager, flags->opstats_last_successful_contribution_criteria()); std::unique_ptr env_eligibility_example_iterator_factory = CreateSimpleTaskEnvironmentIteratorFactory( env_deps, eligibility_selector_context); std::vector eligibility_example_iterator_factories{ &opstats_example_iterator_factory, env_eligibility_example_iterator_factory.get()}; // Note that this method will update fl_runner_result's fields with values // received over the course of the eligibility eval protocol interaction. absl::StatusOr eligibility_eval_result = IssueEligibilityEvalCheckinAndRunPlan( eligibility_example_iterator_factories, should_abort, phase_logger, files, log_manager, opstats_logger, flags, federated_protocol, timing_config, reference_time, fl_runner_result); if (!eligibility_eval_result.ok()) { return fl_runner_result; } auto checkin_result = IssueCheckin(phase_logger, log_manager, files, federated_protocol, std::move(eligibility_eval_result->task_eligibility_info), reference_time, population_name, fl_runner_result, flags); if (!checkin_result.ok()) { return fl_runner_result; } SelectorContext federated_selector_context_with_task_name = federated_selector_context; federated_selector_context_with_task_name.mutable_computation_properties() ->mutable_federated() ->set_task_name(checkin_result->task_name); if (flags->enable_computation_id()) { federated_selector_context_with_task_name.mutable_computation_properties() ->mutable_federated() ->set_computation_id(checkin_result->computation_id); } if (checkin_result->plan.phase().has_example_query_spec()) { federated_selector_context_with_task_name.mutable_computation_properties() ->set_example_iterator_output_format( ::fcp::client::QueryTimeComputationProperties:: EXAMPLE_QUERY_RESULT); } // Include the last successful contribution timestamp in the // SelectorContext. const auto& opstats_db = opstats_logger->GetOpStatsDb(); if (opstats_db != nullptr) { absl::StatusOr data = opstats_db->Read(); if (data.ok()) { std::optional last_successful_contribution_time = opstats::GetLastSuccessfulContributionTime( *data, checkin_result->task_name); if (last_successful_contribution_time.has_value()) { *(federated_selector_context_with_task_name .mutable_computation_properties() ->mutable_federated() ->mutable_historical_context() ->mutable_last_successful_contribution_time()) = *last_successful_contribution_time; } } } if (checkin_result->plan.phase().has_example_query_spec()) { // Example query plan only supports simple agg for now. *(federated_selector_context_with_task_name .mutable_computation_properties() ->mutable_federated() ->mutable_simple_aggregation()) = SimpleAggregation(); } else { const auto& federated_compute_io_router = checkin_result->plan.phase().federated_compute(); const bool has_simpleagg_tensors = !federated_compute_io_router.output_filepath_tensor_name().empty(); bool all_aggregations_are_secagg = true; for (const auto& aggregation : federated_compute_io_router.aggregations()) { all_aggregations_are_secagg &= aggregation.second.protocol_config_case() == AggregationConfig::kSecureAggregation; } if (!has_simpleagg_tensors && all_aggregations_are_secagg) { federated_selector_context_with_task_name .mutable_computation_properties() ->mutable_federated() ->mutable_secure_aggregation() ->set_minimum_clients_in_server_visible_aggregate( checkin_result->minimum_clients_in_server_visible_aggregate); } else { // Has an output checkpoint, so some tensors must be simply aggregated. *(federated_selector_context_with_task_name .mutable_computation_properties() ->mutable_federated() ->mutable_simple_aggregation()) = SimpleAggregation(); } } RetryWindow report_retry_window; phase_logger.LogComputationStarted(); absl::Time run_plan_start_time = absl::Now(); NetworkStats run_plan_start_network_stats = GetCumulativeNetworkStats(federated_protocol, fedselect_manager); absl::StatusOr checkpoint_output_filename = files->CreateTempFile("output", ".ckp"); if (!checkpoint_output_filename.ok()) { auto status = checkpoint_output_filename.status(); auto message = absl::StrCat( "Could not create temporary output checkpoint file: code: ", status.code(), ", message: ", status.message()); phase_logger.LogComputationIOError( status, ExampleStats(), GetNetworkStatsSince(federated_protocol, fedselect_manager, run_plan_start_network_stats), run_plan_start_time); return fl_runner_result; } // Regular plans can use example iterators from the SimpleTaskEnvironment, // those reading the OpStats DB, or those serving Federated Select slices. std::unique_ptr env_example_iterator_factory = CreateSimpleTaskEnvironmentIteratorFactory( env_deps, federated_selector_context_with_task_name); std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory> fedselect_example_iterator_factory = fedselect_manager->CreateExampleIteratorFactoryForUriTemplate( checkin_result->federated_select_uri_template); std::vector example_iterator_factories{ fedselect_example_iterator_factory.get(), &opstats_example_iterator_factory, env_example_iterator_factory.get()}; PlanResultAndCheckpointFile plan_result_and_checkpoint_file = checkin_result->plan.phase().has_example_query_spec() ? RunPlanWithExampleQuerySpec( example_iterator_factories, opstats_logger, flags, checkin_result->plan, *checkpoint_output_filename) : RunPlanWithTensorflowSpec( example_iterator_factories, should_abort, log_manager, opstats_logger, flags, checkin_result->plan, checkin_result->checkpoint_input_filename, *checkpoint_output_filename, timing_config); // Update the FLRunnerResult fields to account for any network usage during // the execution of the plan (e.g. due to Federated Select slices having // been fetched). UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager, phase_logger, fl_runner_result); auto outcome = plan_result_and_checkpoint_file.plan_result.outcome; absl::StatusOr computation_results; if (outcome == engine::PlanOutcome::kSuccess) { computation_results = CreateComputationResults( checkin_result->plan.phase().has_example_query_spec() ? nullptr : &checkin_result->plan.phase().tensorflow_spec(), plan_result_and_checkpoint_file); } LogComputationOutcome( plan_result_and_checkpoint_file.plan_result, computation_results.status(), phase_logger, GetNetworkStatsSince(federated_protocol, fedselect_manager, run_plan_start_network_stats), run_plan_start_time, reference_time); absl::Status report_result = ReportPlanResult( federated_protocol, phase_logger, std::move(computation_results), run_plan_start_time, reference_time); if (outcome == engine::PlanOutcome::kSuccess && report_result.ok()) { // Only if training succeeded *and* reporting succeeded do we consider // the device to have contributed successfully. fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS); } // Update the FLRunnerResult fields one more time to account for the // "Report" protocol interaction. UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager, phase_logger, fl_runner_result); return fl_runner_result; } FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting( SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher, Files* files, LogManager* log_manager, const Flags* flags, const ClientOnlyPlan& client_plan, const std::string& checkpoint_input_filename, const fcp::client::InterruptibleRunner::TimingConfig& timing_config, const absl::Time run_plan_start_time, const absl::Time reference_time) { FLRunnerTensorflowSpecResult result; result.set_outcome(engine::PhaseOutcome::ERROR); engine::PlanResult plan_result(engine::PlanOutcome::kTensorflowError, absl::UnknownError("")); std::function should_abort = [env_deps, &timing_config]() { return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period); }; auto opstats_logger = engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager, /*session_name=*/"", /*population_name=*/""); PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(), log_manager, flags); // Regular plans can use example iterators from the SimpleTaskEnvironment, // those reading the OpStats DB, or those serving Federated Select slices. // However, we don't provide a Federated Select-specific example iterator // factory. That way, the Federated Select slice queries will be forwarded // to SimpleTaskEnvironment, which can handle them by providing // test-specific slices if they want to. // // Eligibility eval plans can only use iterators from the // SimpleTaskEnvironment and those reading the OpStats DB. opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory( opstats_logger.get(), log_manager, flags->opstats_last_successful_contribution_criteria()); std::unique_ptr env_example_iterator_factory = CreateSimpleTaskEnvironmentIteratorFactory(env_deps, SelectorContext()); std::vector example_iterator_factories{ &opstats_example_iterator_factory, env_example_iterator_factory.get()}; phase_logger.LogComputationStarted(); if (client_plan.phase().has_federated_compute()) { absl::StatusOr checkpoint_output_filename = files->CreateTempFile("output", ".ckp"); if (!checkpoint_output_filename.ok()) { phase_logger.LogComputationIOError( checkpoint_output_filename.status(), ExampleStats(), // Empty network stats, since no network protocol is actually used // in this method. NetworkStats(), run_plan_start_time); return result; } // Regular TensorflowSpec-based plans. PlanResultAndCheckpointFile plan_result_and_checkpoint_file = RunPlanWithTensorflowSpec(example_iterator_factories, should_abort, log_manager, opstats_logger.get(), flags, client_plan, checkpoint_input_filename, *checkpoint_output_filename, timing_config); result.set_checkpoint_output_filename( plan_result_and_checkpoint_file.checkpoint_file); plan_result = std::move(plan_result_and_checkpoint_file.plan_result); } else if (client_plan.phase().has_federated_compute_eligibility()) { // Eligibility eval plans. plan_result = RunEligibilityEvalPlanWithTensorflowSpec( example_iterator_factories, should_abort, log_manager, opstats_logger.get(), flags, client_plan, checkpoint_input_filename, timing_config, run_plan_start_time, reference_time); } else { // This branch shouldn't be taken, unless we add an additional type of // TensorflowSpec-based plan in the future. We return a readable error so // that when such new plan types *are* added, they result in clear // compatibility test failures when such plans are erroneously targeted at // old releases that don't support them yet. event_publisher->PublishIoError("Unsupported TensorflowSpec-based plan"); return result; } // Copy output tensors into the result proto. result.set_outcome( engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome)); if (plan_result.outcome == engine::PlanOutcome::kSuccess) { for (int i = 0; i < plan_result.output_names.size(); i++) { tensorflow::TensorProto output_tensor_proto; plan_result.output_tensors[i].AsProtoField(&output_tensor_proto); (*result.mutable_output_tensors())[plan_result.output_names[i]] = std::move(output_tensor_proto); } phase_logger.LogComputationCompleted( plan_result.example_stats, // Empty network stats, since no network protocol is actually used in // this method. NetworkStats(), run_plan_start_time, reference_time); } else { phase_logger.LogComputationTensorflowError( plan_result.original_status, plan_result.example_stats, NetworkStats(), run_plan_start_time, reference_time); } return result; } } // namespace client } // namespace fcp