xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string>
16 #include <utility>
17 
18 #include "absl/strings/numbers.h"
19 #include "absl/strings/str_split.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/resource_handle.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/shape_inference.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/tstring.h"
31 #include "tensorflow/lite/c/c_api_types.h"
32 #include "tensorflow/lite/c/common.h"
33 #include "tensorflow/lite/core/subgraph.h"
34 #include "tensorflow/lite/delegates/flex/buffer_map_util.h"
35 #include "tensorflow/lite/delegates/flex/subgraph_resource.h"
36 #include "tensorflow/lite/delegates/flex/util.h"
37 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
38 #include "tensorflow/lite/kernels/kernel_util.h"
39 #include "tensorflow/lite/string_util.h"
40 
41 namespace tensorflow {
42 
namespace {
// Index of the op input that holds the subgraph resource key: a scalar string
// used to look up the TFLiteSubgraphResource in the TF resource manager.
constexpr int kTfLiteSubgraphResource = 0;
}
46 
// Op schema. The first input is the scalar string key of the subgraph
// resource; the variadic `args` inputs are forwarded to the TF Lite subgraph.
// Input/output arities and dtypes come from the `Tin`/`Tout` attrs (both may
// be empty). Output shapes depend on the executed subgraph and cannot be
// inferred statically, hence UnknownShape.
REGISTER_OP("TfLiteSubgraphExecute")
    .Input("subgraph_key: string")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);
54 
55 // The `TfLiteSubgraphExecute` executes a tflite subgraph with the designated
56 // inputs. This op will first look up the tflite subgraph from TF resource
57 // manager based on the resource name stored on the first input, and then it
58 // will call that specific subgraph with the remaining arguments. The first
59 // input of this op is always a scalar string, which denotes the name of the
60 // subgraph resource. The remaining inputs will be fed to the subgraph as
61 // inputs, so the caller needs to ensure that the remaining inputs match with
62 // the subgraph's expected inputs. This is currently WIP/experimental and
63 // subject to change.
64 class TfLiteSubgraphExecute : public OpKernel {
65  public:
TfLiteSubgraphExecute(OpKernelConstruction * ctx)66   explicit TfLiteSubgraphExecute(OpKernelConstruction* ctx)
67       : OpKernel(ctx), tfl_tensors_need_allocation_(true) {}
68 
Compute(OpKernelContext * ctx)69   void Compute(OpKernelContext* ctx) override {
70     // Fetch the TF Lite subgraph to execute.
71     tflite::flex::TFLiteSubgraphResource* resource = nullptr;
72     OP_REQUIRES_OK(
73         ctx,
74         ctx->resource_manager()->Lookup<tflite::flex::TFLiteSubgraphResource>(
75             "flex", ctx->input(kTfLiteSubgraphResource).flat<tstring>()(0),
76             &resource));
77     tensorflow::core::ScopedUnref unref_resource(resource);
78 
79     // Try to acquire a mutex lock from this resource. This is because tflite
80     // subgraph is not thread-safe and we need to guarantee exclusive access to
81     // it.
82     mutex_lock lock(resource->GetExclusiveLock());
83     tflite::Subgraph& subgraph_selected = resource->GetSubgraphResource();
84 
85     OP_REQUIRES(ctx, ctx->num_inputs() == subgraph_selected.inputs().size() + 1,
86                 errors::InvalidArgument("TF Lite subgraph expects ",
87                                         subgraph_selected.inputs().size(),
88                                         " inputs, but received ",
89                                         ctx->num_inputs() - 1, "."));
90 
91     // Resize input tensors if necessary.
92     ResizeInputTensor(ctx, subgraph_selected);
93 
94     if (tfl_tensors_need_allocation_) {
95       OP_REQUIRES(ctx, subgraph_selected.AllocateTensors() == kTfLiteOk,
96                   errors::Internal("Failed to call allocate tensors"));
97       tfl_tensors_need_allocation_ = false;
98     }
99 
100     // Copy input tensors to subgraph.
101     SetSubgraphInput(ctx, subgraph_selected, resource->GetFlexDelegate());
102 
103     OP_REQUIRES(ctx, subgraph_selected.Invoke() == kTfLiteOk,
104                 errors::Internal("Failed to invoke tflite subgraph"));
105 
106     // Copy tflite results.
107     CopyTFLiteSubgraphResult(ctx, subgraph_selected);
108   }
109 
110  private:
ResizeInputTensor(OpKernelContext * ctx,tflite::Subgraph & subgraph_selected)111   void ResizeInputTensor(OpKernelContext* ctx,
112                          tflite::Subgraph& subgraph_selected) {
113     for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
114       // Shift index by 1 since the first input is always the resource name.
115       const Tensor& tf_tensor = ctx->input(i + 1);
116       TfLiteTensor* subgraph_input =
117           subgraph_selected.tensor(subgraph_selected.inputs()[i]);
118 
119       bool need_resize = false;
120       for (int dim = 0; dim < tf_tensor.shape().dims(); dim++) {
121         if (tf_tensor.shape().dim_size(dim) !=
122             subgraph_input->dims->data[dim]) {
123           need_resize = true;
124           break;
125         }
126       }
127       if (need_resize) {
128         std::vector<int> new_shape;
129         for (auto dim : tf_tensor.shape().dim_sizes()) {
130           new_shape.push_back(dim);
131         }
132         tfl_tensors_need_allocation_ = true;
133         OP_REQUIRES(ctx,
134                     subgraph_selected.ResizeInputTensor(
135                         subgraph_selected.inputs()[i], new_shape) == kTfLiteOk,
136                     errors::Internal("Failed to resize tflite tensor"));
137       }
138     }
139   }
140 
SetSubgraphInput(OpKernelContext * ctx,tflite::Subgraph & subgraph_selected,TfLiteDelegate * flex_delegate) const141   void SetSubgraphInput(OpKernelContext* ctx,
142                         tflite::Subgraph& subgraph_selected,
143                         TfLiteDelegate* flex_delegate) const {
144     auto InitializeVariantOrResource = [flex_delegate](
145                                            const Tensor& tf_tensor,
146                                            TfLiteTensor* subgraph_input) {
147       // The code here initializes the TfLiteTensor which points the data field
148       // to the original TF resource or variant tensor. This requires the TF
149       // tensor's lifetime must extend beyond the execution of callee subgraph.
150       // TODO(b/179094265): This is an experimental implementation, subject to
151       // change. This can be re-implemented with life cycle management
152       // mechanism like reference counting.
153       const size_t required_bytes = sizeof(tensorflow::Tensor**);
154       const tensorflow::Tensor** tf_tensor_ptr =
155           reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
156       *tf_tensor_ptr = &tf_tensor;
157 
158       TfLiteTensorDataFree(subgraph_input);
159       subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
160       subgraph_input->bytes = required_bytes;
161       subgraph_input->data_is_stale = true;
162       subgraph_input->delegate = flex_delegate;
163     };
164 
165     for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
166       const Tensor& tf_tensor = ctx->input(i + 1);
167       TfLiteTensor* subgraph_input =
168           subgraph_selected.tensor(subgraph_selected.inputs()[i]);
169 
170       if (subgraph_input->type == kTfLiteString) {
171         OP_REQUIRES(ctx, tf_tensor.dtype() == tensorflow::DT_STRING,
172                     errors::InvalidArgument("Tensor doesn't have string type"));
173         tflite::DynamicBuffer dynamic_buffer;
174         auto tf_data = tf_tensor.flat<tensorflow::tstring>();
175         for (int i = 0; i < tf_tensor.NumElements(); ++i) {
176           dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
177         }
178 
179         dynamic_buffer.WriteToTensor(subgraph_input, /*new_shape=*/nullptr);
180       } else if (subgraph_input->type == kTfLiteResource) {
181         // Here we will try to parse the input tensor handle to see if it
182         // contains a valid TF lite resource ID. If not, then we know that the
183         // input is a TF resource tensor.
184         tensorflow::ResourceHandle handle =
185             tf_tensor.flat<tensorflow::ResourceHandle>()(0);
186         if (!tflite::flex::GetTfLiteResourceTensorFromResourceHandle(
187                 handle, subgraph_input)) {
188           InitializeVariantOrResource(tf_tensor, subgraph_input);
189         }
190       } else if (subgraph_input->type == kTfLiteVariant) {
191         InitializeVariantOrResource(tf_tensor, subgraph_input);
192       } else {
193         tensorflow::StringPiece tensor_data = tf_tensor.tensor_data();
194         OP_REQUIRES(ctx, subgraph_input->bytes == tensor_data.size(),
195                     errors::Internal("tensor size doesn't match"));
196         // TODO(b/181352924): This could incur some overhead in memory copy.
197         // Optimize this away in the future.
198         memcpy(subgraph_input->data.raw, tensor_data.data(),
199                tensor_data.size());
200       }
201     }
202   }
203 
CopyTFLiteSubgraphResult(OpKernelContext * ctx,tflite::Subgraph & subgraph_selected) const204   void CopyTFLiteSubgraphResult(OpKernelContext* ctx,
205                                 tflite::Subgraph& subgraph_selected) const {
206     for (int i = 0; i < subgraph_selected.outputs().size(); ++i) {
207       OP_REQUIRES(ctx,
208                   subgraph_selected.EnsureTensorDataIsReadable(
209                       subgraph_selected.outputs()[i]) == kTfLiteOk,
210                   errors::Internal("TF lite subgraph output is not readable"));
211       // Create an output tensor.
212       TfLiteTensor* subgraph_output =
213           subgraph_selected.tensor(subgraph_selected.outputs()[i]);
214 
215       // Forcing a memcpy for each tensor output from the called dataset
216       // subgraph. This is because the callee subgraph might be invoked
217       // repeatedly for each item in the dataset, and the result TfLiteTensor's
218       // data should be immediately copied into tensorflow::Tensor.
219       Tensor tensor;
220       OP_REQUIRES_OK(
221           ctx, tflite::flex::SetTfTensorFromTfLite(subgraph_output, &tensor,
222                                                    /*allow_reusing=*/false));
223       ctx->set_output(i, std::move(tensor));
224     }
225   }
226 
227   // Tells if the target subgraph needs to invoko AllocateTensors().
228   bool tfl_tensors_need_allocation_;
229 };
230 
// Registers the CPU kernel implementing the TfLiteSubgraphExecute op.
REGISTER_KERNEL_BUILDER(Name("TfLiteSubgraphExecute").Device(DEVICE_CPU),
                        TfLiteSubgraphExecute);
233 
234 }  // namespace tensorflow
235