xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/delegate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/delegate.h"
16 
17 #include <memory>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/context_util.h"
26 #include "tensorflow/lite/core/macros.h"
27 #include "tensorflow/lite/delegates/flex/buffer_map.h"
28 #include "tensorflow/lite/delegates/flex/kernel.h"
29 #include "tensorflow/lite/delegates/flex/util.h"
30 #include "tensorflow/lite/minimal_logging.h"
31 #include "tensorflow/lite/string_util.h"
32 #include "tensorflow/lite/util.h"
33 
34 namespace tflite {
35 
Create(std::unique_ptr<FlexDelegate> base_delegate)36 TfLiteDelegateUniquePtr FlexDelegate::Create(
37     std::unique_ptr<FlexDelegate> base_delegate) {
38   TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
39                        "Created TensorFlow Lite delegate for select TF ops.");
40   if (base_delegate == nullptr) {
41     base_delegate.reset(new FlexDelegate());
42   }
43   auto flex_delegate = TfLiteDelegateFactory::Create(std::move(base_delegate));
44   flex_delegate->CopyFromBufferHandle =
45       [](TfLiteContext* context, TfLiteDelegate* delegate,
46          TfLiteBufferHandle buffer_handle,
47          TfLiteTensor* tensor) -> TfLiteStatus {
48     return reinterpret_cast<FlexDelegate*>(delegate->data_)
49         ->CopyFromBufferHandle(context, buffer_handle, tensor);
50   };
51   flex_delegate->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
52   reinterpret_cast<FlexDelegate*>(flex_delegate->data_)->base_delegate_ =
53       flex_delegate.get();
54   return flex_delegate;
55 }
56 
Initialize(TfLiteContext * context)57 TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
58   // If the TensorFlow Lite thread count is explicitly configured, use it,
59   // otherwise rely on the default TensorFlow threading behavior.
60   tensorflow::SessionOptions session_options;
61   // We don't run multiple ops at the same time, so prefer using
62   // 1 thread for inter-op parallelism.
63   // Negative value means all are done on the caller thread.
64   session_options.config.set_inter_op_parallelism_threads(-1);
65   if (context->recommended_num_threads > 0) {
66     session_options.config.set_intra_op_parallelism_threads(
67         context->recommended_num_threads);
68   }
69 
70   auto status = delegate_data_.Prepare(
71       session_options, reinterpret_cast<Subgraph*>(context->impl_),
72       base_delegate_);
73   if (!status.ok()) {
74     TF_LITE_KERNEL_LOG(context, "Failed to initialize TensorFlow context: %s",
75                        status.error_message().c_str());
76     return kTfLiteError;
77   }
78 
79   // Initializes the cancellation manager.
80   if (!cancellation_manager_) {
81     cancellation_manager_ = std::make_unique<tensorflow::CancellationManager>();
82     delegate_data_.SetCancellationManager(cancellation_manager_.get());
83   }
84 
85   return kTfLiteOk;
86 }
87 
Name() const88 const char* FlexDelegate::Name() const {
89   static constexpr char kName[] = "TfLiteFlexDelegate";
90   return kName;
91 }
92 
IsNodeSupportedByDelegate(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context) const93 bool FlexDelegate::IsNodeSupportedByDelegate(
94     const TfLiteRegistration* registration, const TfLiteNode* node,
95     TfLiteContext* context) const {
96   return IsFlexOp(registration->custom_name);
97 }
98 
99 std::unique_ptr<SimpleDelegateKernelInterface>
CreateDelegateKernelInterface()100 FlexDelegate::CreateDelegateKernelInterface() {
101   return std::unique_ptr<SimpleDelegateKernelInterface>(
102       new tflite::flex::DelegateKernel());
103 }
104 
CopyFromBufferHandle(TfLiteContext * context,TfLiteBufferHandle buffer_handle,TfLiteTensor * output)105 TfLiteStatus FlexDelegate::CopyFromBufferHandle(
106     TfLiteContext* context, TfLiteBufferHandle buffer_handle,
107     TfLiteTensor* output) {
108   flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);
109 
110   if (!buffer_map->HasTensor(buffer_handle)) {
111     TF_LITE_KERNEL_LOG(context, "Invalid tensor index %d.", buffer_handle);
112     return kTfLiteError;
113   }
114 
115   tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle);
116 
117   if (output->type == kTfLiteString) {
118     if (t.dtype() != tensorflow::DT_STRING) {
119       TF_LITE_KERNEL_LOG(context,
120                          "Inconsistent type for TF string tensor index %d.",
121                          buffer_handle);
122       return kTfLiteError;
123     }
124     DynamicBuffer dynamic_buffer;
125 
126     auto tf_data = t.flat<tensorflow::tstring>();
127     for (int i = 0; i < t.NumElements(); ++i) {
128       dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
129     }
130 
131     dynamic_buffer.WriteToTensor(output, /*new_shape=*/nullptr);
132     return kTfLiteOk;
133   }
134 
135   // TODO(b/179094265): This is an experimental implementation, subject to
136   // change. This can be re-implemented with life cycle management mechanism
137   // like reference counting.
138   // When copying resource and variant tensors from Flex delegate to TensorFlow
139   // Lite tensors, the CopyFromBufferHandle method of the Flex delegate is
140   // invoked and it will store the `data` field of the given TensorFlow Lite
141   // tensor and pass the TensorFlow Lite tensor pointer. Copying the `data`
142   // field will act as passing pointers between TensorFlow Lite tensors.
143   //
144   // The life cycle of the pointer will be managed by the reference counting in
145   // the TensorFlow world and the pointer will be freed when all the buffer
146   // maps, who own it, are gone.
147   if (IsResourceOrVariant(output)) {
148     const size_t required_bytes = sizeof(tensorflow::Tensor**);
149     const tensorflow::Tensor** tf_tensor_ptr =
150         reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
151     *tf_tensor_ptr = buffer_map->GetTensorPtr(buffer_handle);
152 
153     TfLiteTensorDataFree(output);
154     output->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
155     output->bytes = required_bytes;
156     output->data_is_stale = true;
157     return kTfLiteOk;
158   }
159 
160   tensorflow::StringPiece t_data = t.tensor_data();
161 
162   if (output->bytes != t_data.size()) {
163     TF_LITE_KERNEL_LOG(context,
164                        absl::StrCat("The given ", output->bytes,
165                                     " bytes are not enough to store "
166                                     "TensorFlow's aligned buffer of size ",
167                                     t_data.size(), " bytes.")
168                            .c_str());
169     return kTfLiteError;
170   }
171 
172   memcpy(output->data.raw, t_data.data(), t_data.size());
173   return kTfLiteOk;
174 }
175 
Cancel()176 void FlexDelegate::Cancel() { cancellation_manager_->StartCancel(); }
177 
HasCancelled(void * data)178 bool FlexDelegate::HasCancelled(void* data) {
179   if (data == nullptr) {
180     return false;
181   }
182 
183   auto* flex_delegate = static_cast<FlexDelegate*>(data);
184   return flex_delegate->cancellation_manager_->IsCancelled();
185 }
186 
187 }  // namespace tflite
188