/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/delegate.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/api.h"
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/serialization.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

#ifndef CL_DELEGATE_NO_GL
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#endif

namespace tflite {
namespace gpu {
namespace {

using delegates::Serialization;
using delegates::SerializationParams;

constexpr char kSerializedDataPrefix[] = "gpuv2_data_";
InferencePriority ToPriority(int32_t priority) {
  switch (priority) {
    case TFLITE_GPU_INFERENCE_PRIORITY_AUTO:
      return InferencePriority::AUTO;
    case TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION:
      return InferencePriority::MAX_PRECISION;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY:
      return InferencePriority::MIN_LATENCY;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE:
      return InferencePriority::MIN_MEMORY_USAGE;
  }
  return InferencePriority::UNKNOWN;
}

InferenceUsage ToUsage(int32_t usage) {
  switch (usage) {
    case TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER:
      return InferenceUsage::FAST_SINGLE_ANSWER;
    case TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED:
      return InferenceUsage::SUSTAINED_SPEED;
  }
  return InferenceUsage::UNKNOWN;
}

// Forward declarations.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

class Delegate {
 public:
  explicit Delegate(const TfLiteGpuDelegateOptionsV2* options)
      : num_delegate_kernels_(0) {
    delegate_.data_ = reinterpret_cast<void*>(this);
    delegate_.Prepare = DelegatePrepare;
    delegate_.CopyFromBufferHandle = nullptr;
    delegate_.CopyToBufferHandle = nullptr;
    delegate_.FreeBufferHandle = nullptr;
    delegate_.flags = kTfLiteDelegateFlagsNone;
    options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default();
    if (options_.max_delegated_partitions <= 0) {
      options_.max_delegated_partitions = 1;
    }
    if (options_.experimental_flags &
            TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION &&
        options_.model_token && options_.serialization_dir) {
      SerializationParams params;
      params.model_token = options_.model_token;
      params.cache_dir = options_.serialization_dir;
      serialization_ = std::make_unique<Serialization>(params);
    }
  }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }
  Serialization* serialization() { return serialization_.get(); }
  const TfLiteGpuDelegateOptionsV2& options() const { return options_; }

  bool IsQuantOpsAllowed() const {
    return options_.experimental_flags &
           TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
  }
  int MaxDelegatedPartitions() const {
    return options_.max_delegated_partitions;
  }
  int num_delegate_kernels() const { return num_delegate_kernels_; }

 private:
  TfLiteDelegate delegate_;
  TfLiteGpuDelegateOptionsV2 options_;
  int num_delegate_kernels_ = 0;

  std::unique_ptr<Serialization> serialization_;

  friend class DelegateKernel;
};
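
// Illustrative sketch (client-side code, not compiled in this file): to enable
// the on-disk serialization handled by the constructor above, a caller would
// set the serialization-related options before creating the delegate. The
// token and directory values below are placeholders, not real paths.
//
//   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
//   options.experimental_flags |=
//       TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION;
//   options.model_token = "my_model_v1";             // placeholder
//   options.serialization_dir = "/data/local/tmp";   // placeholder
//   TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);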

// Represents the execution of a subset of nodes on the GPU.
class DelegateKernel {
 public:
  explicit DelegateKernel(Delegate* delegate) : delegate_(delegate) {
    ++delegate_->num_delegate_kernels_;
  }
  ~DelegateKernel() { --delegate_->num_delegate_kernels_; }

  absl::Status Prepare(TfLiteContext* context,
                       const TfLiteDelegateParams* delegate_params) {
    thread_id_prepare_ = std::this_thread::get_id();

    // Extract TFLite delegate execution plan from the context and convert it
    // into GraphFloat32.
    GraphFloat32 graph;
    std::vector<uint32_t> input_refs;
    std::vector<uint32_t> output_refs;
    RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
                                    &input_refs, &output_refs));

    std::unique_ptr<InferenceBuilder> builder;
    bool graph_is_destroyed;
    const int experimental_flags = delegate_->options().experimental_flags;
    if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenClApi(&graph, &builder, &graph_is_destroyed,
                                          context, delegate_params,
                                          delegate_->serialization()));
    } else if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenGlApi(&graph, &builder));
    } else {
      // By default, we try CL first & fall back to GL if that fails.
      absl::Status status =
          InitializeOpenClApi(&graph, &builder, &graph_is_destroyed, context,
                              delegate_params, delegate_->serialization());
      if (!status.ok()) {
        TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str());
        TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL");

        // Graph needs to be re-created because it is moved above.
        GraphFloat32 graph2;
        if (graph_is_destroyed) {
          RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
                                          &input_refs, &output_refs));
        }
        RETURN_IF_ERROR(InitializeOpenGlApi(
            graph_is_destroyed ? &graph2 : &graph, &builder));
      }
    }

    // At this point, TFLite hasn't allocated tensors yet, so collect their
    // indices now and set the actual input/output tensors from TFLite later.
    input_indices_.reserve(input_refs.size());
    for (uint32_t tensor_index : input_refs) {
      const int64_t object_index = input_indices_.size();
      input_indices_.push_back(tensor_index);
      const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
      const DataType data_type = ToDataType(tflite_tensor.type);
      RETURN_IF_ERROR(builder->SetInputObjectDef(
          object_index, GetObjectDef(tensor_index, data_type)));
    }
    output_indices_.reserve(output_refs.size());
    for (uint32_t tensor_index : output_refs) {
      const int64_t object_index = output_indices_.size();
      output_indices_.push_back(tensor_index);
      const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
      const DataType data_type = ToDataType(tflite_tensor.type);
      RETURN_IF_ERROR(builder->SetOutputObjectDef(
          object_index, GetObjectDef(tensor_index, data_type)));
    }

    return builder->Build(&runner_);
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensors;
    for (auto index : input_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    for (auto index : output_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensors.size());
    for (int i = 0; i < temporary_tensors.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensors[i];
    }
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (thread_id_prepare_ != std::this_thread::get_id()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "GpuDelegate invoke thread != prepare thread");
      if (enforce_same_thread_) {
        return absl::FailedPreconditionError(
            "GpuDelegate must run on the same thread where it was "
            "initialized.");
      }
    }

    const bool is_dequant_required = !quant_conversion_map_.empty();
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          DequantizeInputs(context, input_indices_, quant_conversion_map_));
    }
    RETURN_IF_ERROR(SetInputsAndOutputs(context));
    RETURN_IF_ERROR(runner_->Run());
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          QuantizeOutputs(context, output_indices_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

 private:
  absl::Status SetInputsAndOutputs(TfLiteContext* context) {
    for (int i = 0; i < input_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetInputObject(
          i, GetTensorObject(input_indices_[i], context)));
    }
    for (int i = 0; i < output_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetOutputObject(
          i, GetTensorObject(output_indices_[i], context)));
    }
    return absl::OkStatus();
  }

  ObjectDef GetObjectDef(int index,
                         DataType data_type = DataType::FLOAT32) const {
    ObjectDef default_object_def;
    default_object_def.data_type = data_type;
    default_object_def.data_layout = DataLayout::BHWC;
    default_object_def.object_type = ObjectType::CPU_MEMORY;
    default_object_def.user_provided = true;
    return default_object_def;
  }

  TensorObject GetTensorObject(int index, TfLiteContext* context) const {
    auto& tensor = context->tensors[index];
    return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes));
  }

 private:
  absl::Status InitializeGraph(TfLiteContext* context,
                               const TfLiteDelegateParams* delegate_params,
                               GraphFloat32* graph,
                               std::vector<uint32_t>* input_refs,
                               std::vector<uint32_t>* output_refs) {
    quant_conversion_map_.clear();
    if (delegate_->IsQuantOpsAllowed()) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
                                      &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
    }

    // TfLiteDelegateParams.input_tensors is an array of all input tensors,
    // including static weights. GraphFloat32.inputs() is an array of runtime
    // tensors that don't have a producer, and its order may not match the one
    // defined by TfLiteDelegateParams.input_tensors. The two sets are not the
    // same, especially for a multi-partition delegation. They are matched by
    // filtering TfLiteDelegateParams.input_tensors with
    // !tflite::IsConstantTensor() and then inserting the remaining tensors in
    // the order specified by TfLiteDelegateParams.input_tensors. This logic is
    // shared with ModelBuilder::PrecreateIOTensors(), which is eventually
    // called by BuildFinalModel() above.
    //
    // Similarly, TfLiteDelegateParams.output_tensors is an array of all output
    // tensors, and can contain static tensors with buggy conversion.
    // GraphFloat32.outputs() is an array of runtime tensors that don't have a
    // consumer (this is a bug in the assumption), and its order may not match
    // the one defined by TfLiteDelegateParams.output_tensors. Again, the two
    // sets are not the same, especially for a multi-partition delegation. They
    // are matched by inserting tensors in the order defined by
    // TfLiteDelegateParams.output_tensors. This logic is likewise shared with
    // ModelBuilder::PrecreateIOTensors(), which is eventually called by
    // BuildFinalModel() above.
    //
    // The aforementioned matching in BuildFinalModel() is ported here to match
    // input/output_refs.
    // TODO(b/211393366): Fix this at the GraphFloat32.inputs/outputs() level.
    const std::vector<Value*> inputs = graph->inputs();
    input_refs->clear();
    input_refs->reserve(delegate_params->input_tensors->size);
    for (int i = 0, j = 0; i < delegate_params->input_tensors->size; ++i) {
      const TfLiteTensor* tensor =
          context->tensors + delegate_params->input_tensors->data[i];
      if (tflite::IsConstantTensor(tensor)) continue;
      input_refs->push_back(inputs[j]->tensor.ref);
      ++j;
    }
    const std::vector<Value*> outputs = graph->outputs();
    output_refs->clear();
    const int output_size = std::min(static_cast<int>(graph->outputs().size()),
                                     delegate_params->output_tensors->size);
    output_refs->reserve(output_size);
    for (int i = 0; i < output_size; ++i) {
      output_refs->push_back(outputs[i]->tensor.ref);
    }

    return absl::OkStatus();
  }

  absl::Status InitializeOpenClApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder,
                                   bool* graph_is_destroyed,
                                   TfLiteContext* context,
                                   const TfLiteDelegateParams* delegate_params,
                                   Serialization* serialization = nullptr) {
    *graph_is_destroyed = false;
    cl::InferenceEnvironmentOptions env_options;
    cl::InferenceEnvironmentProperties properties;

    // OpenCL initialization is parameterized by these InferenceOptions.
    auto delegate_options = delegate_->options();
    cl::InferenceOptions options;
    // If is_precision_loss_allowed == -1, then just use priorities instead
    // of paying attention to is_precision_loss_allowed value.
    if (delegate_options.is_precision_loss_allowed == -1) {
      options.priority1 = ToPriority(delegate_options.inference_priority1);
      options.priority2 = ToPriority(delegate_options.inference_priority2);
      options.priority3 = ToPriority(delegate_options.inference_priority3);
    } else {
      // The user set is_precision_loss_allowed explicitly, so use that value.
      if (delegate_options.is_precision_loss_allowed == 0) {
        options.priority1 = InferencePriority::MAX_PRECISION;
      } else {
        options.priority1 = InferencePriority::MIN_LATENCY;
      }
    }
    options.usage = ToUsage(delegate_options.inference_preference);

    if (!serialization) {
      // This path is faster when there is no serialization involved.
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
          options, std::move(*graph), builder));
    } else {
      // If serialization data is found, initialize CL from it & return early.
      if (MaybeInitializeSerializedOpenCL(context, delegate_params, builder,
                                          &options, &env_options, &properties,
                                          serialization)
              .ok()) {
        return absl::OkStatus();
      }

      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      std::vector<uint8_t> serialized_model;
      RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
          options, std::move(*graph), &serialized_model));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(serialized_model, builder));

      RETURN_IF_ERROR(SaveSerializedOpenCL(context, delegate_params, &options,
                                           serialization, serialized_model));
    }

    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenCL-based API.");
    return absl::OkStatus();
  }

  // Returns Ok only if serialized data is successfully found.
  absl::Status MaybeInitializeSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      std::unique_ptr<InferenceBuilder>* builder, cl::InferenceOptions* options,
      cl::InferenceEnvironmentOptions* env_options,
      cl::InferenceEnvironmentProperties* properties,
      Serialization* serialization) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);

    std::string model_data;
    auto model_data_status = data_key.GetData(context, &model_data);
    if (model_data_status == kTfLiteOk) {
      absl::Span<const uint8_t> model_span = absl::Span<const uint8_t>{
          reinterpret_cast<const uint8_t*>(model_data.data()),
          model_data.size()};
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(
          *env_options, &cl_environment_, properties));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(model_span, builder));
      TFLITE_LOG_PROD_ONCE(
          tflite::TFLITE_LOG_INFO,
          "Initialized OpenCL-based API from serialized data.");
      return absl::OkStatus();
    }

    return absl::NotFoundError("Serialization data not found");
  }

  // Returns Ok only if serialization happens successfully.
  absl::Status SaveSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      cl::InferenceOptions* options, Serialization* serialization,
      const std::vector<uint8_t>& serialized_model) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));

    // Save data.
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);
    auto save_status = data_key.SetData(
        context, reinterpret_cast<const char*>(serialized_model.data()),
        serialized_model.size());
    if (save_status != kTfLiteOk) {
      return absl::InvalidArgumentError("Failed to save serialized data");
    }
    return absl::OkStatus();
  }

  absl::Status InitializeOpenGlApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder) {
#ifndef CL_DELEGATE_NO_GL
    gl::InferenceEnvironmentOptions env_options;
    gl::InferenceEnvironmentProperties properties;
    RETURN_IF_ERROR(
        NewInferenceEnvironment(env_options, &gl_environment_, &properties));
    auto delegate_options = delegate_->options();
    gl::InferenceOptions options;
    options.usage = ToUsage(delegate_options.inference_preference);
    options.priority1 = ToPriority(delegate_options.inference_priority1);
    options.priority2 = ToPriority(delegate_options.inference_priority2);
    options.priority3 = ToPriority(delegate_options.inference_priority3);
    RETURN_IF_ERROR(gl_environment_->NewInferenceBuilder(std::move(*graph),
                                                         options, builder));
    enforce_same_thread_ = true;
    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenGL-based API.");
    return absl::OkStatus();
#else
    return absl::UnavailableError("OpenGL-based API disabled");
#endif
  }

  // The Delegate instance that's shared across all DelegateKernel instances.
  Delegate* const delegate_;  // doesn't own the memory.
  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#ifndef CL_DELEGATE_NO_GL
  std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#endif
  std::unique_ptr<InferenceRunner> runner_;
  std::vector<int64_t> input_indices_;
  std::vector<int64_t> output_indices_;
  // Whenever quantized inference is enabled, this maps the tensor index of
  // each originally quantized (8-bit) tensor to its float version added in
  // model_builder - and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;
  std::thread::id thread_id_prepare_;  // thread id used for Prepare()
  bool enforce_same_thread_ = false;   // flag to enforce same thread for Invoke()
};

inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) {
  return reinterpret_cast<DelegateKernel*>(node->user_data);
}

inline Delegate* GetDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params =
            reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* gpu_delegate = GetDelegate(params->delegate);
        // Everything below should happen in prepare function call, but TFLite
        // for whatever reason forbids that.
        auto gpu_delegate_kernel =
            std::make_unique<DelegateKernel>(gpu_delegate);
        const auto status = gpu_delegate_kernel->Prepare(context, params);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Init: %s",
                             std::string(status.message()).c_str());
          return nullptr;
        }
        return gpu_delegate_kernel.release();
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {
        delete reinterpret_cast<DelegateKernel*>(buffer);
      },
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          TF_LITE_KERNEL_LOG(
              context,
              "TfLiteGpuDelegate Prepare: delegate is not initialized");
          return kTfLiteError;
        }
        auto* gpu_delegate_kernel = GetDelegateKernel(node);
        const auto status = gpu_delegate_kernel->GetRequiredTemporaries(
            context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        // TODO(akulik): tflite tensors are not allocated here either. It would
        // be good to set inputs and outputs only once here instead of setting
        // them every time in .invoke.
        return kTfLiteOk;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetDelegateKernel(node)->Invoke(context);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Invoke: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return kTfLiteOk;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteGpuDelegateV2",  // .custom_name
      1,                      // .version
  };

  auto* gpu_delegate = GetDelegate(delegate);
  absl::flat_hash_set<TfLiteBuiltinOperator> excluded_ops;
  if (!cl::OpenCLSupported()) {
    excluded_ops.insert(kTfLiteBuiltinSplit);
    excluded_ops.insert(kTfLiteBuiltinSplitV);
  }
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, gpu_delegate->IsQuantOpsAllowed(),
                      gpu_delegate->MaxDelegatedPartitions(), &excluded_ops);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kRegistration, ops_to_replace, delegate);
  TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Created %d GPU delegate kernels.",
                  gpu_delegate->num_delegate_kernels());
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace gpu
}  // namespace tflite

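// Illustrative usage sketch (assumed client-side code, not part of this file):
// a typical caller creates the delegate from options, attaches it to an
// interpreter via the standard TFLite C++ API, and deletes it only after the
// interpreter that uses it has been destroyed.
//
//   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
//   options.inference_preference =
//       TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
//   TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Handle failure: the delegate could not be applied to the graph.
//   }
//   ...
//   interpreter.reset();
//   TfLiteGpuDelegateV2Delete(delegate);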
TfLiteDelegate* TfLiteGpuDelegateV2Create(
    const TfLiteGpuDelegateOptionsV2* options) {
  auto* gpu_delegate = new tflite::gpu::Delegate(options);
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                       "Created TensorFlow Lite delegate for GPU.");
  return gpu_delegate ? gpu_delegate->tflite_delegate() : nullptr;
}

void TfLiteGpuDelegateV2Delete(TfLiteDelegate* delegate) {
  delete tflite::gpu::GetDelegate(delegate);
}