/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/flex/buffer_map_util.h"
#include "tensorflow/lite/delegates/flex/subgraph_resource.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tensorflow {

namespace {
constexpr int kTfLiteSubgraphResource = 0;
}  // namespace

REGISTER_OP("TfLiteSubgraphExecute")
    .Input("subgraph_key: string")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// `TfLiteSubgraphExecute` executes a TFLite subgraph with the designated
// inputs. The op first looks up the TFLite subgraph in the TF resource
// manager, using the resource name stored in the first input, and then
// invokes that subgraph with the remaining arguments. The first input is
// always a scalar string that names the subgraph resource; the remaining
// inputs are fed to the subgraph, so the caller must ensure they match the
// subgraph's expected inputs. This op is currently WIP/experimental and
// subject to change.
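//
// A minimal calling-convention sketch (illustrative only; "my_subgraph" is a
// hypothetical resource key, not one defined in this file). The caller is
// assumed to have registered a tflite::flex::TFLiteSubgraphResource with the
// TF resource manager under the "flex" container and that key; the op then
// takes:
//
//   subgraph_key: scalar string tensor holding "my_subgraph"
//   args:         one tensor per subgraph input, in the subgraph's input order
//   output:       one tensor per subgraph output, with dtypes given by `Tout`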
class TfLiteSubgraphExecute : public OpKernel {
 public:
  explicit TfLiteSubgraphExecute(OpKernelConstruction* ctx)
      : OpKernel(ctx), tfl_tensors_need_allocation_(true) {}

  void Compute(OpKernelContext* ctx) override {
    // Fetch the TF Lite subgraph to execute.
    tflite::flex::TFLiteSubgraphResource* resource = nullptr;
    OP_REQUIRES_OK(
        ctx,
        ctx->resource_manager()->Lookup<tflite::flex::TFLiteSubgraphResource>(
            "flex", ctx->input(kTfLiteSubgraphResource).flat<tstring>()(0),
            &resource));
    tensorflow::core::ScopedUnref unref_resource(resource);

    // Acquire the resource's mutex lock: the TFLite subgraph is not
    // thread-safe, so we need to guarantee exclusive access to it.
    mutex_lock lock(resource->GetExclusiveLock());
    tflite::Subgraph& subgraph_selected = resource->GetSubgraphResource();

    OP_REQUIRES(ctx, ctx->num_inputs() == subgraph_selected.inputs().size() + 1,
                errors::InvalidArgument("TF Lite subgraph expects ",
                                        subgraph_selected.inputs().size(),
                                        " inputs, but received ",
                                        ctx->num_inputs() - 1, "."));

    // Resize input tensors if necessary.
    ResizeInputTensor(ctx, subgraph_selected);

    if (tfl_tensors_need_allocation_) {
      OP_REQUIRES(ctx, subgraph_selected.AllocateTensors() == kTfLiteOk,
                  errors::Internal("Failed to call allocate tensors"));
      tfl_tensors_need_allocation_ = false;
    }

    // Copy input tensors to the subgraph.
    SetSubgraphInput(ctx, subgraph_selected, resource->GetFlexDelegate());

    OP_REQUIRES(ctx, subgraph_selected.Invoke() == kTfLiteOk,
                errors::Internal("Failed to invoke tflite subgraph"));

    // Copy the TFLite results back into TF output tensors.
    CopyTFLiteSubgraphResult(ctx, subgraph_selected);
  }

 private:
  void ResizeInputTensor(OpKernelContext* ctx,
                         tflite::Subgraph& subgraph_selected) {
    for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
      // Shift the index by 1 since the first input is always the resource
      // name.
      const Tensor& tf_tensor = ctx->input(i + 1);
      TfLiteTensor* subgraph_input =
          subgraph_selected.tensor(subgraph_selected.inputs()[i]);

      bool need_resize = false;
      for (int dim = 0; dim < tf_tensor.shape().dims(); dim++) {
        if (tf_tensor.shape().dim_size(dim) !=
            subgraph_input->dims->data[dim]) {
          need_resize = true;
          break;
        }
      }
      if (need_resize) {
        std::vector<int> new_shape;
        for (auto dim : tf_tensor.shape().dim_sizes()) {
          new_shape.push_back(dim);
        }
        tfl_tensors_need_allocation_ = true;
        OP_REQUIRES(ctx,
                    subgraph_selected.ResizeInputTensor(
                        subgraph_selected.inputs()[i], new_shape) == kTfLiteOk,
                    errors::Internal("Failed to resize tflite tensor"));
      }
    }
  }

  void SetSubgraphInput(OpKernelContext* ctx,
                        tflite::Subgraph& subgraph_selected,
                        TfLiteDelegate* flex_delegate) const {
    auto InitializeVariantOrResource = [flex_delegate](
                                           const Tensor& tf_tensor,
                                           TfLiteTensor* subgraph_input) {
      // Initialize the TfLiteTensor so that its data field points at the
      // original TF resource or variant tensor. This requires that the TF
      // tensor's lifetime extend beyond the execution of the callee subgraph.
      // TODO(b/179094265): This is an experimental implementation, subject to
      // change. It can be re-implemented with a lifecycle management
      // mechanism such as reference counting.
      const size_t required_bytes = sizeof(tensorflow::Tensor**);
      const tensorflow::Tensor** tf_tensor_ptr =
          reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
      *tf_tensor_ptr = &tf_tensor;

      TfLiteTensorDataFree(subgraph_input);
      subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
      subgraph_input->bytes = required_bytes;
      subgraph_input->data_is_stale = true;
      subgraph_input->delegate = flex_delegate;
    };

    for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
      const Tensor& tf_tensor = ctx->input(i + 1);
      TfLiteTensor* subgraph_input =
          subgraph_selected.tensor(subgraph_selected.inputs()[i]);

      if (subgraph_input->type == kTfLiteString) {
        OP_REQUIRES(ctx, tf_tensor.dtype() == tensorflow::DT_STRING,
                    errors::InvalidArgument("Tensor doesn't have string type"));
        tflite::DynamicBuffer dynamic_buffer;
        auto tf_data = tf_tensor.flat<tensorflow::tstring>();
        for (int i = 0; i < tf_tensor.NumElements(); ++i) {
          dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
        }

        dynamic_buffer.WriteToTensor(subgraph_input, /*new_shape=*/nullptr);
      } else if (subgraph_input->type == kTfLiteResource) {
        // Try to parse the input tensor handle to see if it contains a valid
        // TFLite resource ID. If not, the input is a TF resource tensor.
        tensorflow::ResourceHandle handle =
            tf_tensor.flat<tensorflow::ResourceHandle>()(0);
        if (!tflite::flex::GetTfLiteResourceTensorFromResourceHandle(
                handle, subgraph_input)) {
          InitializeVariantOrResource(tf_tensor, subgraph_input);
        }
      } else if (subgraph_input->type == kTfLiteVariant) {
        InitializeVariantOrResource(tf_tensor, subgraph_input);
      } else {
        tensorflow::StringPiece tensor_data = tf_tensor.tensor_data();
        OP_REQUIRES(ctx, subgraph_input->bytes == tensor_data.size(),
                    errors::Internal("tensor size doesn't match"));
        // TODO(b/181352924): This could incur some overhead in memory copy.
        // Optimize this away in the future.
        memcpy(subgraph_input->data.raw, tensor_data.data(),
               tensor_data.size());
      }
    }
  }

  void CopyTFLiteSubgraphResult(OpKernelContext* ctx,
                                tflite::Subgraph& subgraph_selected) const {
    for (int i = 0; i < subgraph_selected.outputs().size(); ++i) {
      OP_REQUIRES(ctx,
                  subgraph_selected.EnsureTensorDataIsReadable(
                      subgraph_selected.outputs()[i]) == kTfLiteOk,
                  errors::Internal("TF lite subgraph output is not readable"));
      // Create an output tensor.
      TfLiteTensor* subgraph_output =
          subgraph_selected.tensor(subgraph_selected.outputs()[i]);

      // Force a memcpy for each tensor output from the called dataset
      // subgraph. The callee subgraph might be invoked repeatedly for each
      // item in the dataset, so the resulting TfLiteTensor's data should be
      // copied into the tensorflow::Tensor immediately.
      Tensor tensor;
      OP_REQUIRES_OK(
          ctx, tflite::flex::SetTfTensorFromTfLite(subgraph_output, &tensor,
                                                   /*allow_reusing=*/false));
      ctx->set_output(i, std::move(tensor));
    }
  }

  // Whether the target subgraph still needs to invoke AllocateTensors().
  bool tfl_tensors_need_allocation_;
};

REGISTER_KERNEL_BUILDER(Name("TfLiteSubgraphExecute").Device(DEVICE_CPU),
                        TfLiteSubgraphExecute);

}  // namespace tensorflow