1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/delegate.h"
16
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/context_util.h"
26 #include "tensorflow/lite/core/macros.h"
27 #include "tensorflow/lite/delegates/flex/buffer_map.h"
28 #include "tensorflow/lite/delegates/flex/kernel.h"
29 #include "tensorflow/lite/delegates/flex/util.h"
30 #include "tensorflow/lite/minimal_logging.h"
31 #include "tensorflow/lite/string_util.h"
32 #include "tensorflow/lite/util.h"
33
34 namespace tflite {
35
Create(std::unique_ptr<FlexDelegate> base_delegate)36 TfLiteDelegateUniquePtr FlexDelegate::Create(
37 std::unique_ptr<FlexDelegate> base_delegate) {
38 TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
39 "Created TensorFlow Lite delegate for select TF ops.");
40 if (base_delegate == nullptr) {
41 base_delegate.reset(new FlexDelegate());
42 }
43 auto flex_delegate = TfLiteDelegateFactory::Create(std::move(base_delegate));
44 flex_delegate->CopyFromBufferHandle =
45 [](TfLiteContext* context, TfLiteDelegate* delegate,
46 TfLiteBufferHandle buffer_handle,
47 TfLiteTensor* tensor) -> TfLiteStatus {
48 return reinterpret_cast<FlexDelegate*>(delegate->data_)
49 ->CopyFromBufferHandle(context, buffer_handle, tensor);
50 };
51 flex_delegate->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
52 reinterpret_cast<FlexDelegate*>(flex_delegate->data_)->base_delegate_ =
53 flex_delegate.get();
54 return flex_delegate;
55 }
56
Initialize(TfLiteContext * context)57 TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
58 // If the TensorFlow Lite thread count is explicitly configured, use it,
59 // otherwise rely on the default TensorFlow threading behavior.
60 tensorflow::SessionOptions session_options;
61 // We don't run multiple ops at the same time, so prefer using
62 // 1 thread for inter-op parallelism.
63 // Negative value means all are done on the caller thread.
64 session_options.config.set_inter_op_parallelism_threads(-1);
65 if (context->recommended_num_threads > 0) {
66 session_options.config.set_intra_op_parallelism_threads(
67 context->recommended_num_threads);
68 }
69
70 auto status = delegate_data_.Prepare(
71 session_options, reinterpret_cast<Subgraph*>(context->impl_),
72 base_delegate_);
73 if (!status.ok()) {
74 TF_LITE_KERNEL_LOG(context, "Failed to initialize TensorFlow context: %s",
75 status.error_message().c_str());
76 return kTfLiteError;
77 }
78
79 // Initializes the cancellation manager.
80 if (!cancellation_manager_) {
81 cancellation_manager_ = std::make_unique<tensorflow::CancellationManager>();
82 delegate_data_.SetCancellationManager(cancellation_manager_.get());
83 }
84
85 return kTfLiteOk;
86 }
87
Name() const88 const char* FlexDelegate::Name() const {
89 static constexpr char kName[] = "TfLiteFlexDelegate";
90 return kName;
91 }
92
IsNodeSupportedByDelegate(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context) const93 bool FlexDelegate::IsNodeSupportedByDelegate(
94 const TfLiteRegistration* registration, const TfLiteNode* node,
95 TfLiteContext* context) const {
96 return IsFlexOp(registration->custom_name);
97 }
98
99 std::unique_ptr<SimpleDelegateKernelInterface>
CreateDelegateKernelInterface()100 FlexDelegate::CreateDelegateKernelInterface() {
101 return std::unique_ptr<SimpleDelegateKernelInterface>(
102 new tflite::flex::DelegateKernel());
103 }
104
CopyFromBufferHandle(TfLiteContext * context,TfLiteBufferHandle buffer_handle,TfLiteTensor * output)105 TfLiteStatus FlexDelegate::CopyFromBufferHandle(
106 TfLiteContext* context, TfLiteBufferHandle buffer_handle,
107 TfLiteTensor* output) {
108 flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);
109
110 if (!buffer_map->HasTensor(buffer_handle)) {
111 TF_LITE_KERNEL_LOG(context, "Invalid tensor index %d.", buffer_handle);
112 return kTfLiteError;
113 }
114
115 tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle);
116
117 if (output->type == kTfLiteString) {
118 if (t.dtype() != tensorflow::DT_STRING) {
119 TF_LITE_KERNEL_LOG(context,
120 "Inconsistent type for TF string tensor index %d.",
121 buffer_handle);
122 return kTfLiteError;
123 }
124 DynamicBuffer dynamic_buffer;
125
126 auto tf_data = t.flat<tensorflow::tstring>();
127 for (int i = 0; i < t.NumElements(); ++i) {
128 dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
129 }
130
131 dynamic_buffer.WriteToTensor(output, /*new_shape=*/nullptr);
132 return kTfLiteOk;
133 }
134
135 // TODO(b/179094265): This is an experimental implementation, subject to
136 // change. This can be re-implemented with life cycle management mechanism
137 // like reference counting.
138 // When copying resource and variant tensors from Flex delegate to TensorFlow
139 // Lite tensors, the CopyFromBufferHandle method of the Flex delegate is
140 // invoked and it will store the `data` field of the given TensorFlow Lite
141 // tensor and pass the TensorFlow Lite tensor pointer. Copying the `data`
142 // field will act as passing pointers between TensorFlow Lite tensors.
143 //
144 // The life cycle of the pointer will be managed by the reference counting in
145 // the TensorFlow world and the pointer will be freed when all the buffer
146 // maps, who own it, are gone.
147 if (IsResourceOrVariant(output)) {
148 const size_t required_bytes = sizeof(tensorflow::Tensor**);
149 const tensorflow::Tensor** tf_tensor_ptr =
150 reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
151 *tf_tensor_ptr = buffer_map->GetTensorPtr(buffer_handle);
152
153 TfLiteTensorDataFree(output);
154 output->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
155 output->bytes = required_bytes;
156 output->data_is_stale = true;
157 return kTfLiteOk;
158 }
159
160 tensorflow::StringPiece t_data = t.tensor_data();
161
162 if (output->bytes != t_data.size()) {
163 TF_LITE_KERNEL_LOG(context,
164 absl::StrCat("The given ", output->bytes,
165 " bytes are not enough to store "
166 "TensorFlow's aligned buffer of size ",
167 t_data.size(), " bytes.")
168 .c_str());
169 return kTfLiteError;
170 }
171
172 memcpy(output->data.raw, t_data.data(), t_data.size());
173 return kTfLiteOk;
174 }
175
Cancel()176 void FlexDelegate::Cancel() { cancellation_manager_->StartCancel(); }
177
HasCancelled(void * data)178 bool FlexDelegate::HasCancelled(void* data) {
179 if (data == nullptr) {
180 return false;
181 }
182
183 auto* flex_delegate = static_cast<FlexDelegate*>(data);
184 return flex_delegate->cancellation_manager_->IsCancelled();
185 }
186
187 } // namespace tflite
188