/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fcp/client/engine/tflite_plan_engine.h"

#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "fcp/client/engine/plan_engine_helpers.h"
#include "fcp/client/engine/tflite_wrapper.h"
#include "fcp/protos/plan.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace fcp {
namespace client {
namespace engine {

using ::google::internal::federated::plan::TensorflowSpec;

namespace {

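// Converts the output of a TfLiteWrapper run into a PlanResult, mapping the
// run's status code to the corresponding plan outcome and attaching the
// collected example statistics on success.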
PlanResult CreatePlanResultFromOutput(
    absl::StatusOr<OutputTensors> output, std::atomic<int>* total_example_count,
    std::atomic<int64_t>* total_example_size_bytes,
    absl::Status example_iterator_status) {
  switch (output.status().code()) {
    case absl::StatusCode::kOk: {
      PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
      plan_result.output_names = std::move(output->output_tensor_names);
      plan_result.output_tensors = std::move(output->output_tensors);
      plan_result.example_stats = {
          .example_count = *total_example_count,
          .example_size_bytes = *total_example_size_bytes};
      return plan_result;
    }
    case absl::StatusCode::kCancelled:
      return PlanResult(PlanOutcome::kInterrupted, std::move(output.status()));
    case absl::StatusCode::kInvalidArgument:
      return CreateComputationErrorPlanResult(example_iterator_status,
                                              output.status());
    default:
      FCP_LOG(FATAL) << "unexpected status code: " << output.status().code();
  }
  // Unreachable code.
  return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
}

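// Builds the TfLiteInterpreterOptions for this run from the client flags.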
TfLiteInterpreterOptions CreateOptions(const Flags& flags) {
  return TfLiteInterpreterOptions{
      .ensure_dynamic_tensors_are_released =
          flags.ensure_dynamic_tensors_are_released(),
      .large_tensor_threshold_for_dynamic_allocation =
          flags.large_tensor_threshold_for_dynamic_allocation(),
      .disable_delegate_clustering =
          flags.disable_tflite_delegate_clustering()};
}
}  // namespace

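// Runs a TensorFlow Lite plan: validates the TensorflowSpec against the
// provided inputs and output names, registers a host object that serves
// dataset examples to the model, optionally injects string constant inputs,
// creates a TfLiteWrapper for the model, runs it, and converts the result
// into a PlanResult.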
PlanResult TfLitePlanEngine::RunPlan(
    const TensorflowSpec& tensorflow_spec, const std::string& model,
    std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
    const std::vector<std::string>& output_names) {
  FCP_LOG(INFO) << "***** start running plan";
  log_manager_->LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED);
  // Check that all inputs have corresponding TensorSpecProtos.
  absl::flat_hash_set<std::string> expected_input_tensor_names_set;
  for (auto it = inputs->begin(); it != inputs->end(); it++) {
    expected_input_tensor_names_set.insert(it->first);
  }
  absl::Status validity_checks = ValidateTensorflowSpec(
      tensorflow_spec, expected_input_tensor_names_set, output_names);
  if (!validity_checks.ok()) {
    FCP_LOG(ERROR) << validity_checks.message();
    return PlanResult(PlanOutcome::kInvalidArgument,
                      std::move(validity_checks));
  }
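  // Register a host object that serves examples to the model through the
  // dataset token tensor, tracking how many examples (and bytes) were
  // consumed as well as any example iterator error.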
  std::atomic<int> total_example_count = 0;
  std::atomic<int64_t> total_example_size_bytes = 0;
  ExampleIteratorStatus example_iterator_status;
  HostObjectRegistration host_registration = AddDatasetTokenToInputsForTfLite(
      example_iterator_factories_, opstats_logger_, inputs.get(),
      tensorflow_spec.dataset_token_tensor_name(), &total_example_count,
      &total_example_size_bytes, &example_iterator_status);
  // If the constant inputs are provided and the flag is enabled, add these to
  // the map of TFLite inputs.
  if (!tensorflow_spec.constant_inputs().empty()) {
    FCP_LOG(INFO) << "***** constant_inputs is not empty";
    if (!flags_.support_constant_tf_inputs()) {
      return PlanResult(
          PlanOutcome::kInvalidArgument,
          absl::InternalError(
              "Cannot run constant_inputs when experiment is disabled."));
    } else {
      for (const auto& [name, tensor_proto] :
           tensorflow_spec.constant_inputs()) {
        tensorflow::Tensor input_tensor;
        if (!input_tensor.FromProto(tensor_proto)) {
          FCP_LOG(ERROR) << "unable to convert constant_input to tensor: "
                         << tensor_proto.DebugString();
          return PlanResult(PlanOutcome::kInvalidArgument,
                            absl::InternalError(
                                "Unable to convert constant_input to tensor"));
        }
        // Convert the Tensor to its TFLite representation and add it as a
        // string to the inputs map.
        if (input_tensor.dtype() == tensorflow::DT_STRING) {
          tensorflow::tstring str_data =
              input_tensor.scalar<tensorflow::tstring>()();
          inputs->insert({name, std::string(str_data.data(), str_data.size())});
        } else {
          FCP_LOG(ERROR) << "Constant input tensor is not a string tensor. "
                            "Currently only string tensors are supported.";
          return PlanResult(
              PlanOutcome::kInvalidArgument,
              absl::InternalError("Only string tensors are supported"));
        }
      }
    }
  }
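  // Create the TFLite wrapper that will execute the model with the prepared
  // inputs, honoring the abort signal and timing configuration.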
  absl::StatusOr<std::unique_ptr<TfLiteWrapper>> tflite_wrapper =
      TfLiteWrapper::Create(model, should_abort_, *timing_config_, log_manager_,
                            std::move(inputs), output_names,
                            CreateOptions(flags_),
                            flags_.num_threads_for_tflite());
  FCP_LOG(INFO) << "***** create tflite wrapper";

  if (!tflite_wrapper.ok()) {
    return PlanResult(PlanOutcome::kTensorflowError, tflite_wrapper.status());
  }
  // Start running the plan.
  absl::StatusOr<OutputTensors> output = (*tflite_wrapper)->Run();
  PlanResult plan_result = CreatePlanResultFromOutput(
      std::move(output), &total_example_count, &total_example_size_bytes,
      example_iterator_status.GetStatus());
  return plan_result;
}

}  // namespace engine
}  // namespace client
}  // namespace fcp