/* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fcp/client/engine/tflite_plan_engine.h" #include #include #include #include #include "fcp/client/engine/plan_engine_helpers.h" #include "fcp/client/engine/tflite_wrapper.h" #include "fcp/protos/plan.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/protobuf/struct.pb.h" namespace fcp { namespace client { namespace engine { using ::google::internal::federated::plan::TensorflowSpec; namespace { PlanResult CreatePlanResultFromOutput( absl::StatusOr output, std::atomic* total_example_count, std::atomic* total_example_size_bytes, absl::Status example_iterator_status) { switch (output.status().code()) { case absl::StatusCode::kOk: { PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus()); plan_result.output_names = std::move(output->output_tensor_names); plan_result.output_tensors = std::move(output->output_tensors); plan_result.example_stats = { .example_count = *total_example_count, .example_size_bytes = *total_example_size_bytes}; return plan_result; } case absl::StatusCode::kCancelled: return PlanResult(PlanOutcome::kInterrupted, std::move(output.status())); case absl::StatusCode::kInvalidArgument: return CreateComputationErrorPlanResult(example_iterator_status, output.status()); default: FCP_LOG(FATAL) << "unexpected status code: " << output.status().code(); } // Unreachable code. return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError("")); } TfLiteInterpreterOptions CreateOptions(const Flags& flags) { return TfLiteInterpreterOptions{ .ensure_dynamic_tensors_are_released = flags.ensure_dynamic_tensors_are_released(), .large_tensor_threshold_for_dynamic_allocation = flags.large_tensor_threshold_for_dynamic_allocation(), .disable_delegate_clustering = flags.disable_tflite_delegate_clustering()}; } } // namespace PlanResult TfLitePlanEngine::RunPlan( const TensorflowSpec& tensorflow_spec, const std::string& model, std::unique_ptr> inputs, const std::vector& output_names) { FCP_LOG(INFO) << "***** start running plan"; log_manager_->LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED); // Check that all inputs have corresponding TensorSpecProtos. absl::flat_hash_set expected_input_tensor_names_set; for (auto it = inputs->begin(); it != inputs->end(); it++) { expected_input_tensor_names_set.insert(it->first); } absl::Status validity_checks = ValidateTensorflowSpec( tensorflow_spec, expected_input_tensor_names_set, output_names); if (!validity_checks.ok()) { FCP_LOG(ERROR) << validity_checks.message(); return PlanResult(PlanOutcome::kInvalidArgument, std::move(validity_checks)); } std::atomic total_example_count = 0; std::atomic total_example_size_bytes = 0; ExampleIteratorStatus example_iterator_status; HostObjectRegistration host_registration = AddDatasetTokenToInputsForTfLite( example_iterator_factories_, opstats_logger_, inputs.get(), tensorflow_spec.dataset_token_tensor_name(), &total_example_count, &total_example_size_bytes, &example_iterator_status); // If the constant inputs are provided and the flag is enabled, add these to // the map of TFLite inputs. if (!tensorflow_spec.constant_inputs().empty()) { FCP_LOG(INFO) << "***** constant inputs is not empty"; if (!flags_.support_constant_tf_inputs()) { return PlanResult( PlanOutcome::kInvalidArgument, absl::InternalError( "Cannot run constant_inputs when experiment is disabled.")); } else { for (const auto& [name, tensor_proto] : tensorflow_spec.constant_inputs()) { tensorflow::Tensor input_tensor; if (!input_tensor.FromProto(tensor_proto)) { FCP_LOG(ERROR) << "unable to convert constant_input to tensor: " << tensor_proto.DebugString(); return PlanResult(PlanOutcome::kInvalidArgument, absl::InternalError( "Unable to convert constant_input to tensor")); } // Convert Tensor to TFLite represenation and add this as a string to // inputs. if (input_tensor.dtype() == tensorflow::DT_STRING) { tensorflow::tstring str_data = input_tensor.scalar()(); inputs->insert({name, std::string(str_data.data(), str_data.size())}); } else { FCP_LOG(ERROR) << "Constant input tensor is not a string tensor. " "Currently only string tensors are supported."; return PlanResult( PlanOutcome::kInvalidArgument, absl::InternalError("Only string tensors are supported")); } } } } absl::StatusOr> tflite_wrapper = TfLiteWrapper::Create(model, should_abort_, *timing_config_, log_manager_, std::move(inputs), output_names, CreateOptions(flags_), flags_.num_threads_for_tflite()); FCP_LOG(INFO) << "***** create tflite wrapper"; if (!tflite_wrapper.ok()) { return PlanResult(PlanOutcome::kTensorflowError, tflite_wrapper.status()); } // Start running the plan. absl::StatusOr output = (*tflite_wrapper)->Run(); PlanResult plan_result = CreatePlanResultFromOutput( std::move(output), &total_example_count, &total_example_size_bytes, example_iterator_status.GetStatus()); return plan_result; } } // namespace engine } // namespace client } // namespace fcp