1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/engine/tflite_plan_engine.h"
17
18 #include <functional>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "fcp/client/engine/plan_engine_helpers.h"
24 #include "fcp/client/engine/tflite_wrapper.h"
25 #include "fcp/protos/plan.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/protobuf/struct.pb.h"
28
29 namespace fcp {
30 namespace client {
31 namespace engine {
32
33 using ::google::internal::federated::plan::TensorflowSpec;
34
35 namespace {
36
CreatePlanResultFromOutput(absl::StatusOr<OutputTensors> output,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,absl::Status example_iterator_status)37 PlanResult CreatePlanResultFromOutput(
38 absl::StatusOr<OutputTensors> output, std::atomic<int>* total_example_count,
39 std::atomic<int64_t>* total_example_size_bytes,
40 absl::Status example_iterator_status) {
41 switch (output.status().code()) {
42 case absl::StatusCode::kOk: {
43 PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
44 plan_result.output_names = std::move(output->output_tensor_names);
45 plan_result.output_tensors = std::move(output->output_tensors);
46 plan_result.example_stats = {
47 .example_count = *total_example_count,
48 .example_size_bytes = *total_example_size_bytes};
49 return plan_result;
50 }
51 case absl::StatusCode::kCancelled:
52 return PlanResult(PlanOutcome::kInterrupted, std::move(output.status()));
53 case absl::StatusCode::kInvalidArgument:
54 return CreateComputationErrorPlanResult(example_iterator_status,
55 output.status());
56 default:
57 FCP_LOG(FATAL) << "unexpected status code: " << output.status().code();
58 }
59 // Unreachable code.
60 return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
61 }
62
CreateOptions(const Flags & flags)63 TfLiteInterpreterOptions CreateOptions(const Flags& flags) {
64 return TfLiteInterpreterOptions{
65 .ensure_dynamic_tensors_are_released =
66 flags.ensure_dynamic_tensors_are_released(),
67 .large_tensor_threshold_for_dynamic_allocation =
68 flags.large_tensor_threshold_for_dynamic_allocation(),
69 .disable_delegate_clustering =
70 flags.disable_tflite_delegate_clustering()};
71 }
72 } // namespace
73
RunPlan(const TensorflowSpec & tensorflow_spec,const std::string & model,std::unique_ptr<absl::flat_hash_map<std::string,std::string>> inputs,const std::vector<std::string> & output_names)74 PlanResult TfLitePlanEngine::RunPlan(
75 const TensorflowSpec& tensorflow_spec, const std::string& model,
76 std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
77 const std::vector<std::string>& output_names) {
78 FCP_LOG(INFO) << "***** start running plan";
79 log_manager_->LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED);
80 // Check that all inputs have corresponding TensorSpecProtos.
81 absl::flat_hash_set<std::string> expected_input_tensor_names_set;
82 for (auto it = inputs->begin(); it != inputs->end(); it++) {
83 expected_input_tensor_names_set.insert(it->first);
84 }
85 absl::Status validity_checks = ValidateTensorflowSpec(
86 tensorflow_spec, expected_input_tensor_names_set, output_names);
87 if (!validity_checks.ok()) {
88 FCP_LOG(ERROR) << validity_checks.message();
89 return PlanResult(PlanOutcome::kInvalidArgument,
90 std::move(validity_checks));
91 }
92 std::atomic<int> total_example_count = 0;
93 std::atomic<int64_t> total_example_size_bytes = 0;
94 ExampleIteratorStatus example_iterator_status;
95 HostObjectRegistration host_registration = AddDatasetTokenToInputsForTfLite(
96 example_iterator_factories_, opstats_logger_, inputs.get(),
97 tensorflow_spec.dataset_token_tensor_name(), &total_example_count,
98 &total_example_size_bytes, &example_iterator_status);
99 // If the constant inputs are provided and the flag is enabled, add these to
100 // the map of TFLite inputs.
101 if (!tensorflow_spec.constant_inputs().empty()) {
102 FCP_LOG(INFO) << "***** constant inputs is not empty";
103 if (!flags_.support_constant_tf_inputs()) {
104 return PlanResult(
105 PlanOutcome::kInvalidArgument,
106 absl::InternalError(
107 "Cannot run constant_inputs when experiment is disabled."));
108 } else {
109 for (const auto& [name, tensor_proto] :
110 tensorflow_spec.constant_inputs()) {
111 tensorflow::Tensor input_tensor;
112 if (!input_tensor.FromProto(tensor_proto)) {
113 FCP_LOG(ERROR) << "unable to convert constant_input to tensor: "
114 << tensor_proto.DebugString();
115 return PlanResult(PlanOutcome::kInvalidArgument,
116 absl::InternalError(
117 "Unable to convert constant_input to tensor"));
118 }
119 // Convert Tensor to TFLite represenation and add this as a string to
120 // inputs.
121 if (input_tensor.dtype() == tensorflow::DT_STRING) {
122 tensorflow::tstring str_data =
123 input_tensor.scalar<tensorflow::tstring>()();
124 inputs->insert({name, std::string(str_data.data(), str_data.size())});
125 } else {
126 FCP_LOG(ERROR) << "Constant input tensor is not a string tensor. "
127 "Currently only string tensors are supported.";
128 return PlanResult(
129 PlanOutcome::kInvalidArgument,
130 absl::InternalError("Only string tensors are supported"));
131 }
132 }
133 }
134 }
135 absl::StatusOr<std::unique_ptr<TfLiteWrapper>> tflite_wrapper =
136 TfLiteWrapper::Create(model, should_abort_, *timing_config_, log_manager_,
137 std::move(inputs), output_names,
138 CreateOptions(flags_),
139 flags_.num_threads_for_tflite());
140 FCP_LOG(INFO) << "***** create tflite wrapper";
141
142 if (!tflite_wrapper.ok()) {
143 return PlanResult(PlanOutcome::kTensorflowError, tflite_wrapper.status());
144 }
145 // Start running the plan.
146 absl::StatusOr<OutputTensors> output = (*tflite_wrapper)->Run();
147 PlanResult plan_result = CreatePlanResultFromOutput(
148 std::move(output), &total_example_count, &total_example_size_bytes,
149 example_iterator_status.GetStatus());
150 return plan_result;
151 }
152
153 } // namespace engine
154 } // namespace client
155 } // namespace fcp
156