/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fcp/client/engine/plan_engine_helpers.h"

// NOTE(review): the standard headers and the template arguments in this file
// (notably the std::atomic element types and tensorflow::tstring) were
// restored from usage; confirm them against plan_engine_helpers.h.
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "fcp/client/diag_codes.pb.h"
#include "fcp/client/opstats/opstats_logger.h"
#include "fcp/client/opstats/opstats_logger_impl.h"
// #include "fcp/client/opstats/pds_backed_opstats_db.h"
#include "fcp/protos/plan.pb.h"
#include "fcp/tensorflow/external_dataset.h"

namespace fcp {
namespace client {
namespace engine {

namespace {

using ::fcp::client::opstats::OpStatsLogger;
using ::fcp::client::opstats::OpStatsLoggerImpl;
// using ::fcp::client::opstats::PdsBackedOpStatsDb;
using ::google::internal::federated::plan::ExampleSelector;

/** An iterator that forwards the failing status from the external dataset to
 * TensorFlow.
 *
 * Every call to GetNext() returns the status supplied at construction.
 */
class FailingDatasetIterator : public ExternalDatasetIterator {
 public:
  // `status` is a by-value sink parameter, so it is moved into the member
  // rather than copied a second time.
  explicit FailingDatasetIterator(absl::Status status)
      : status_(std::move(status)) {}

  absl::StatusOr<std::string> GetNext() final { return status_; }

 private:
  const absl::Status status_;
};

/** Dataset provider that builds datasets backed by the example iterator
 * factories handed to it.
 *
 * The raw pointers (factories, logger, counters, status holder) are stored and
 * copied into the dataset closure without any ownership transfer being visible
 * here; presumably the caller keeps them alive for as long as datasets created
 * by this provider are in use -- TODO confirm against callers.
 */
class TrainingDatasetProvider
    : public ExternalDatasetProvider::UsingProtoSelector<ExampleSelector> {
 public:
  TrainingDatasetProvider(
      std::vector<ExampleIteratorFactory*> example_iterator_factories,
      OpStatsLogger* opstats_logger, std::atomic<int>* total_example_count,
      std::atomic<int64_t>* total_example_size_bytes,
      ExampleIteratorStatus* example_iterator_status)
      : example_iterator_factories_(std::move(example_iterator_factories)),
        opstats_logger_(opstats_logger),
        total_example_count_(total_example_count),
        total_example_size_bytes_(total_example_size_bytes),
        example_iterator_status_(example_iterator_status) {}

  /** Creates a dataset whose iterators are produced lazily for `selector`.
   *
   * The returned dataset's iterator-creation function:
   *  - looks up a factory that can handle `selector`;
   *  - on lookup or creation failure, records the error in the shared
   *    ExampleIteratorStatus and returns a FailingDatasetIterator carrying it;
   *  - otherwise wraps the created example iterator in a DatasetIterator.
   */
  absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
      ExampleSelector selector) final {
    // Members are captured by copy so the closure does not reference `this`.
    return ExternalDataset::FromFunction(
        [example_iterator_factories = example_iterator_factories_,
         opstats_logger = opstats_logger_, selector,
         total_example_count = total_example_count_,
         total_example_size_bytes = total_example_size_bytes_,
         example_iterator_status = example_iterator_status_]()
            -> std::unique_ptr<ExternalDatasetIterator> {
          ExampleIteratorFactory* example_iterator_factory =
              FindExampleIteratorFactory(selector, example_iterator_factories);
          // The DatasetOp requires a valid iterator at this stage so return an
          // empty iterator if there was an error.
          if (example_iterator_factory == nullptr) {
            absl::Status error(
                absl::StatusCode::kInternal,
                "Could not find suitable ExampleIteratorFactory");
            example_iterator_status->SetStatus(error);
            return std::make_unique<FailingDatasetIterator>(error);
          }

          absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
              example_iterator_factory->CreateExampleIterator(selector);
          if (!example_iterator.ok()) {
            example_iterator_status->SetStatus(example_iterator.status());
            return std::make_unique<FailingDatasetIterator>(
                example_iterator.status());
          }

          return std::make_unique<DatasetIterator>(
              std::move(*example_iterator), opstats_logger, total_example_count,
              total_example_size_bytes, example_iterator_status,
              selector.collection_uri(),
              /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
        });
  }

 private:
  std::vector<ExampleIteratorFactory*> example_iterator_factories_;
  OpStatsLogger* opstats_logger_;
  std::atomic<int>* total_example_count_;
  std::atomic<int64_t>* total_example_size_bytes_;
  ExampleIteratorStatus* example_iterator_status_;
};

}  // namespace

// Wraps `example_iterator` and remembers where to report statistics and
// statuses. The iterator start time is captured at construction.
DatasetIterator::DatasetIterator(
    std::unique_ptr<ExampleIterator> example_iterator,
    opstats::OpStatsLogger* opstats_logger,
    std::atomic<int>* total_example_count,
    std::atomic<int64_t>* total_example_size_bytes,
    ExampleIteratorStatus* example_iterator_status,
    const std::string& collection_uri, bool collect_stats)
    : example_iterator_(std::move(example_iterator)),
      opstats_logger_(opstats_logger),
      iterator_start_time_(absl::Now()),
      total_example_count_(total_example_count),
      total_example_size_bytes_(total_example_size_bytes),
      example_iterator_status_(example_iterator_status),
      example_count_(0),
      example_size_bytes_(0),
      collection_uri_(collection_uri),
      iterator_finished_(false),
      collect_stats_(collect_stats) {}

// Reports the per-dataset example count and byte size to the opstats logger,
// but only when stats collection was enabled for this iterator.
DatasetIterator::~DatasetIterator() {
  if (collect_stats_) {
    opstats_logger_->UpdateDatasetStats(collection_uri_, example_count_,
                                        example_size_bytes_);
  }
}

// Returns the next entry from the dataset.
//
// The whole call runs under `iterator_lock_`. Once the wrapped iterator has
// reported OUT_OF_RANGE it is closed and every later call returns
// OUT_OF_RANGE without touching it again.
absl::StatusOr<std::string> DatasetIterator::GetNext() {
  absl::MutexLock locked(&iterator_lock_);
  if (iterator_finished_) {
    // If we've reached the end of the iterator, always return OUT_OF_RANGE.
    return absl::OutOfRangeError("End of iterator reached");
  }

  absl::StatusOr<std::string> example = example_iterator_->Next();
  absl::StatusCode error_code = example.status().code();
  // Publish the outcome; SetStatus() itself filters out OK / OUT_OF_RANGE.
  example_iterator_status_->SetStatus(example.status());
  if (error_code == absl::StatusCode::kOutOfRange) {
    example_iterator_->Close();
    iterator_finished_ = true;
  }

  // Only successfully read examples count toward the metrics; errors and the
  // end-of-iterator signal are not recorded as examples.
  if (collect_stats_ && example.ok()) {
    // TODO(team): Consider reducing logic duplication in
    // cross-dataset and single-dataset example stat variables.
    *total_example_count_ += 1;
    *total_example_size_bytes_ += example->size();
    example_count_ += 1;
    example_size_bytes_ += example->size();
  }
  return example;
}

// Stores `status` if it represents a real error.
void ExampleIteratorStatus::SetStatus(absl::Status status) {
  absl::MutexLock lock(&mu_);
  // We ignore normal statuses such as OK and OUT_OF_RANGE to avoid a race
  // condition where an error is recorded and then an OUT_OF_RANGE or OK status
  // returned on a different thread overwrites that error.
  if (status.code() != absl::StatusCode::kOk &&
      status.code() != absl::StatusCode::kOutOfRange) {
    // `status` is a by-value parameter that is not used afterwards.
    status_ = std::move(status);
  }
}

// Returns a copy of the stored status, read under the mutex.
absl::Status ExampleIteratorStatus::GetStatus() {
  absl::MutexLock lock(&mu_);
  return status_;
}

// Registers a TrainingDatasetProvider and appends its token to `inputs` as a
// string tensor named `dataset_token_tensor_name`. Returns the registration
// object obtained from the registry.
HostObjectRegistration AddDatasetTokenToInputs(
    std::vector<ExampleIteratorFactory*> example_iterator_factories,
    OpStatsLogger* opstats_logger,
    std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
    const std::string& dataset_token_tensor_name,
    std::atomic<int>* total_example_count,
    std::atomic<int64_t>* total_example_size_bytes,
    ExampleIteratorStatus* example_iterator_status) {
  // Register the TrainingDatasetProvider with the global
  // ExternalDatasetProviderRegistry.
  auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
      std::make_shared<TrainingDatasetProvider>(
          example_iterator_factories, opstats_logger, total_example_count,
          total_example_size_bytes, example_iterator_status));
  // Pack the token returned from registering the provider into a string
  // tensor. TensorFlow will use that token via the ExternalDatasetOp to create
  // datasets and iterators.
  tensorflow::Tensor token_scalar(std::string{});
  token_scalar.scalar<tensorflow::tstring>()() =
      host_registration.token().ToString();
  // Construct the pair in place instead of copying a temporary pair.
  inputs->emplace_back(dataset_token_tensor_name, std::move(token_scalar));
  return host_registration;
}

// TfLite variant of AddDatasetTokenToInputs: registers a
// TrainingDatasetProvider and stores its token string in the `inputs` map
// under `dataset_token_tensor_name`. Returns the registration object obtained
// from the registry.
HostObjectRegistration AddDatasetTokenToInputsForTfLite(
    std::vector<ExampleIteratorFactory*> example_iterator_factories,
    OpStatsLogger* opstats_logger,
    absl::flat_hash_map<std::string, std::string>* inputs,
    const std::string& dataset_token_tensor_name,
    std::atomic<int>* total_example_count,
    std::atomic<int64_t>* total_example_size_bytes,
    ExampleIteratorStatus* example_iterator_status) {
  // Registers the TrainingDatasetProvider with the global
  // ExternalDatasetProviderRegistry.
  auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
      std::make_shared<TrainingDatasetProvider>(
          example_iterator_factories, opstats_logger, total_example_count,
          total_example_size_bytes, example_iterator_status));
  // Adds the token returned from registering the provider to the map of inputs.
  // TfLite will use that token via the ExternalDatasetOp to create
  // datasets and iterators.
  (*inputs)[dataset_token_tensor_name] = host_registration.token().ToString();
  return host_registration;
}

// Creates an OpStatsLogger whose enabled flag comes from
// `flags->enable_opstats()`. The database-backed implementation is currently
// commented out, so `base_dir`, `log_manager`, `session_name` and
// `population_name` are not used here.
std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
    const std::string& base_dir, const Flags* flags, LogManager* log_manager,
    const std::string& session_name, const std::string& population_name) {
  // if (flags->enable_opstats()) {
  //   auto db_or = PdsBackedOpStatsDb::Create(
  //       base_dir, flags->opstats_ttl_days() * absl::Hours(24), *log_manager,
  //       flags->opstats_db_size_limit_bytes());
  //   if (db_or.ok()) {
  //     return std::make_unique<OpStatsLoggerImpl>(
  //         std::move(db_or).value(), log_manager, flags, session_name,
  //         population_name);
  //   } else {
  //     if (flags->log_opstats_initialization_errors()) {
  //       return std::make_unique<OpStatsLogger>(
  //           /*opstats_enabled=*/flags->enable_opstats(),
  //           /*init_status=*/db_or.status());
  //     }
  //   }
  // }
  return std::make_unique<OpStatsLogger>(
      /*opstats_enabled=*/flags->enable_opstats());
}

// Chooses the PlanResult for a failed computation based on the status recorded
// by the example iterators:
//   - OK / OUT_OF_RANGE -> kTensorflowError with `computation_error_status`;
//   - CANCELLED         -> kInterrupted with `example_iterator_status`;
//   - anything else     -> kExampleIteratorError with
//                          `example_iterator_status`.
PlanResult CreateComputationErrorPlanResult(
    absl::Status example_iterator_status,
    absl::Status computation_error_status) {
  switch (example_iterator_status.code()) {
    case absl::StatusCode::kOk:
    case absl::StatusCode::kOutOfRange:
      // Either example iterators are working fine or we don't know the status
      // of the example iterators. In this case, we'll use the error status
      // returned from TensorFlow.
      return PlanResult(PlanOutcome::kTensorflowError,
                        computation_error_status);
    case absl::StatusCode::kCancelled:
      // Example iterator got interrupted.
      return PlanResult(PlanOutcome::kInterrupted, example_iterator_status);
    default:
      // All other Example iterator errors.
      return PlanResult(PlanOutcome::kExampleIteratorError,
                        example_iterator_status);
  }
}

// Returns the first factory, in vector order, whose CanHandle(selector) is
// true, or nullptr when none matches (including when the vector is empty).
ExampleIteratorFactory* FindExampleIteratorFactory(
    const ExampleSelector& selector,
    std::vector<ExampleIteratorFactory*> example_iterator_factories) {
  for (ExampleIteratorFactory* factory : example_iterator_factories) {
    if (factory->CanHandle(selector)) {
      return factory;
    }
  }
  return nullptr;
}

}  // namespace engine
}  // namespace client
}  // namespace fcp