/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/delegate.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/api.h"
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/serialization.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

#ifndef CL_DELEGATE_NO_GL
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#endif

namespace tflite {
namespace gpu {
namespace {

using delegates::Serialization;
using delegates::SerializationParams;

constexpr char kSerializedDataPrefix[] = "gpuv2_data_";

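// Maps a TFLITE_GPU_INFERENCE_PRIORITY_* value to the corresponding
// InferencePriority; unrecognized values map to UNKNOWN.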
InferencePriority ToPriority(int32_t priority) {
  switch (priority) {
    case TFLITE_GPU_INFERENCE_PRIORITY_AUTO:
      return InferencePriority::AUTO;
    case TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION:
      return InferencePriority::MAX_PRECISION;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY:
      return InferencePriority::MIN_LATENCY;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE:
      return InferencePriority::MIN_MEMORY_USAGE;
  }
  return InferencePriority::UNKNOWN;
}

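// Maps a TFLITE_GPU_INFERENCE_PREFERENCE_* value to the corresponding
// InferenceUsage; unrecognized values map to UNKNOWN.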
InferenceUsage ToUsage(int32_t usage) {
  switch (usage) {
    case TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER:
      return InferenceUsage::FAST_SINGLE_ANSWER;
    case TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED:
      return InferenceUsage::SUSTAINED_SPEED;
  }
  return InferenceUsage::UNKNOWN;
}

// Forward declarations.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

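// Wraps the TfLiteDelegate handed to the TFLite runtime together with the
// user-provided options, and owns the optional serialization helper shared by
// all DelegateKernel instances.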
class Delegate {
 public:
  explicit Delegate(const TfLiteGpuDelegateOptionsV2* options)
      : num_delegate_kernels_(0) {
    delegate_.data_ = reinterpret_cast<void*>(this);
    delegate_.Prepare = DelegatePrepare;
    delegate_.CopyFromBufferHandle = nullptr;
    delegate_.CopyToBufferHandle = nullptr;
    delegate_.FreeBufferHandle = nullptr;
    delegate_.flags = kTfLiteDelegateFlagsNone;
    options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default();
    if (options_.max_delegated_partitions <= 0) {
      options_.max_delegated_partitions = 1;
    }
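    // Serialization is enabled only when the experimental flag is set and
    // both model_token and serialization_dir are provided.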
    if (options_.experimental_flags &
            TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION &&
        options_.model_token && options_.serialization_dir) {
      SerializationParams params;
      params.model_token = options_.model_token;
      params.cache_dir = options_.serialization_dir;
      serialization_ = std::make_unique<Serialization>(params);
    }
  }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }
  Serialization* serialization() { return serialization_.get(); }
  const TfLiteGpuDelegateOptionsV2& options() const { return options_; }

  bool IsQuantOpsAllowed() const {
    return options_.experimental_flags &
           TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
  }
  int MaxDelegatedPartitions() const {
    return options_.max_delegated_partitions;
  }
  int num_delegate_kernels() const { return num_delegate_kernels_; }

 private:
  TfLiteDelegate delegate_;
  TfLiteGpuDelegateOptionsV2 options_;
  int num_delegate_kernels_ = 0;

  std::unique_ptr<Serialization> serialization_;

  friend class DelegateKernel;
};

// Represents the execution of a subset of nodes on the GPU.
class DelegateKernel {
 public:
  explicit DelegateKernel(Delegate* delegate) : delegate_(delegate) {
    ++delegate_->num_delegate_kernels_;
  }
  ~DelegateKernel() { --delegate_->num_delegate_kernels_; }

  absl::Status Prepare(TfLiteContext* context,
                       const TfLiteDelegateParams* delegate_params) {
    thread_id_prepare_ = std::this_thread::get_id();

    // Extract TFLite delegate execution plan from the context and convert it
    // into GraphFloat32.
    GraphFloat32 graph;
    std::vector<uint32_t> input_refs;
    std::vector<uint32_t> output_refs;
    RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
                                    &input_refs, &output_refs));

    std::unique_ptr<InferenceBuilder> builder;
    bool graph_is_destroyed;
    const int experimental_flags = delegate_->options().experimental_flags;
    if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenClApi(&graph, &builder, &graph_is_destroyed,
                                          context, delegate_params,
                                          delegate_->serialization()));
    } else if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenGlApi(&graph, &builder));
    } else {
      // By default, we try CL first & fall back to GL if that fails.
      absl::Status status =
          InitializeOpenClApi(&graph, &builder, &graph_is_destroyed, context,
                              delegate_params, delegate_->serialization());
      if (!status.ok()) {
        TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str());
        TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL");

        // The graph needs to be re-created because it was moved above.
        GraphFloat32 graph2;
        if (graph_is_destroyed) {
          RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
                                          &input_refs, &output_refs));
        }
        RETURN_IF_ERROR(InitializeOpenGlApi(
            graph_is_destroyed ? &graph2 : &graph, &builder));
      }
    }

    // At this point, TFLite hasn't allocated tensors yet; collect the tensor
    // indices now and bind the actual input/output tensors from TFLite later.
    input_indices_.reserve(input_refs.size());
    for (uint32_t tensor_index : input_refs) {
      const int64_t object_index = input_indices_.size();
      input_indices_.push_back(tensor_index);
      const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
      const DataType data_type = ToDataType(tflite_tensor.type);
      RETURN_IF_ERROR(builder->SetInputObjectDef(
          object_index, GetObjectDef(tensor_index, data_type)));
    }
    output_indices_.reserve(output_refs.size());
    for (uint32_t tensor_index : output_refs) {
      const int64_t object_index = output_indices_.size();
      output_indices_.push_back(tensor_index);
      const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
      const DataType data_type = ToDataType(tflite_tensor.type);
      RETURN_IF_ERROR(builder->SetOutputObjectDef(
          object_index, GetObjectDef(tensor_index, data_type)));
    }

    return builder->Build(&runner_);
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensors;
    for (auto index : input_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    for (auto index : output_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensors.size());
    for (int i = 0; i < temporary_tensors.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensors[i];
    }
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (thread_id_prepare_ != std::this_thread::get_id()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "GpuDelegate invoke thread != prepare thread");
      if (enforce_same_thread_) {
        return absl::FailedPreconditionError(
            "GpuDelegate must run on the same thread where it was "
            "initialized.");
      }
    }

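    // When quantized inference is enabled, dequantize the 8-bit inputs into
    // their float counterparts before running, and quantize the float outputs
    // back afterwards.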
    const bool is_dequant_required = !quant_conversion_map_.empty();
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          DequantizeInputs(context, input_indices_, quant_conversion_map_));
    }
    RETURN_IF_ERROR(SetInputsAndOutputs(context));
    RETURN_IF_ERROR(runner_->Run());
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          QuantizeOutputs(context, output_indices_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

 private:
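  // Binds the CPU buffer of every input/output TFLite tensor to the
  // corresponding runner object. Called on every invocation.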
  absl::Status SetInputsAndOutputs(TfLiteContext* context) {
    for (int i = 0; i < input_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetInputObject(
          i, GetTensorObject(input_indices_[i], context)));
    }
    for (int i = 0; i < output_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetOutputObject(
          i, GetTensorObject(output_indices_[i], context)));
    }
    return absl::OkStatus();
  }

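  // I/O objects are exposed to the runner as user-provided CPU memory in
  // BHWC layout; only the data type varies per tensor.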
  ObjectDef GetObjectDef(int index,
                         DataType data_type = DataType::FLOAT32) const {
    ObjectDef default_object_def;
    default_object_def.data_type = data_type;
    default_object_def.data_layout = DataLayout::BHWC;
    default_object_def.object_type = ObjectType::CPU_MEMORY;
    default_object_def.user_provided = true;
    return default_object_def;
  }

  TensorObject GetTensorObject(int index, TfLiteContext* context) const {
    auto& tensor = context->tensors[index];
    return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes));
  }

 private:
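  // Builds a GraphFloat32 for the delegated partition and collects the TFLite
  // tensor indices of the graph's runtime inputs and outputs.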
  absl::Status InitializeGraph(TfLiteContext* context,
                               const TfLiteDelegateParams* delegate_params,
                               GraphFloat32* graph,
                               std::vector<uint32_t>* input_refs,
                               std::vector<uint32_t>* output_refs) {
    quant_conversion_map_.clear();
    if (delegate_->IsQuantOpsAllowed()) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
                                      &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
    }

    // TfLiteDelegateParams.input_tensors is an array of all input tensors
    // including static weights.  GraphFloat32.inputs() is an array of runtime
    // tensors that don't have a producer and the order may not be the same as
    // defined by TfLiteDelegateParams.input_tensors.  These two sets are not
    // the same, especially on a multi-partition delegation.  These are matched
    // by filtering TfLiteDelegateParams.input_tensors with
    // !tflite::IsConstantTensor() and then inserting them in the order
    // specified by TfLiteDelegateParams.input_tensors.  This logic is shared
    // with ModelBuilder::PrecreateIOTensors() which is eventually called with
    // BuildFinalModel() above.
    //
    // Similarly, TfLiteDelegateParams.output_tensors is an array of all output
    // tensors, and can contain static tensors with buggy conversion.
    // GraphFloat32.outputs() is an array of runtime tensors that don't have a
    // consumer (this is a bug in the assumption) and the order may not be the
    // same as defined by TfLiteDelegateParams.output_tensors.  Again, these two
    // sets are not the same, especially on a multi-partition delegation.  These
    // are matched by inserting the tensors by the order defined by
    // TfLiteDelegateParams.output_tensors.  Similarly, this logic is shared
    // with ModelBuilder::PrecreateIOTensors() which is eventually called with
    // BuildFinalModel() above.
    //
    // The aforementioned matching in BuildFinalModel() is ported here to match
    // input/output_refs.
    // TODO(b/211393366): Fix this at GraphFloat32.inputs/outputs() level.
    const std::vector<Value*> inputs = graph->inputs();
    input_refs->clear();
    input_refs->reserve(delegate_params->input_tensors->size);
    for (int i = 0, j = 0; i < delegate_params->input_tensors->size; ++i) {
      const TfLiteTensor* tensor =
          context->tensors + delegate_params->input_tensors->data[i];
      if (tflite::IsConstantTensor(tensor)) continue;
      input_refs->push_back(inputs[j]->tensor.ref);
      ++j;
    }
    const std::vector<Value*> outputs = graph->outputs();
    output_refs->clear();
    const int output_size = std::min(static_cast<int>(graph->outputs().size()),
                                     delegate_params->output_tensors->size);
    output_refs->reserve(output_size);
    for (int i = 0; i < output_size; ++i) {
      output_refs->push_back(outputs[i]->tensor.ref);
    }

    return absl::OkStatus();
  }

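  // Creates the OpenCL environment and inference builder. If a Serialization
  // object is provided, a previously serialized model is loaded when
  // available; otherwise the freshly built model is serialized for next time.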
  absl::Status InitializeOpenClApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder,
                                   bool* graph_is_destroyed,
                                   TfLiteContext* context,
                                   const TfLiteDelegateParams* delegate_params,
                                   Serialization* serialization = nullptr) {
    *graph_is_destroyed = false;
    cl::InferenceEnvironmentOptions env_options;
    cl::InferenceEnvironmentProperties properties;

    // OpenCL initialization is parameterized by these InferenceOptions.
    auto delegate_options = delegate_->options();
    cl::InferenceOptions options;
    // If is_precision_loss_allowed == -1, then use the priorities instead of
    // the is_precision_loss_allowed value.
    if (delegate_options.is_precision_loss_allowed == -1) {
      options.priority1 = ToPriority(delegate_options.inference_priority1);
      options.priority2 = ToPriority(delegate_options.inference_priority2);
      options.priority3 = ToPriority(delegate_options.inference_priority3);
    } else {
      // The user set is_precision_loss_allowed explicitly, so honor it.
      if (delegate_options.is_precision_loss_allowed == 0) {
        options.priority1 = InferencePriority::MAX_PRECISION;
      } else {
        options.priority1 = InferencePriority::MIN_LATENCY;
      }
    }
    options.usage = ToUsage(delegate_options.inference_preference);

    if (!serialization) {
      // This path is faster when there is no serialization involved.
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
          options, std::move(*graph), builder));
    } else {
      // If serialization data is found, initialize CL from it & return early.
      if (MaybeInitializeSerializedOpenCL(context, delegate_params, builder,
                                          &options, &env_options, &properties,
                                          serialization)
              .ok()) {
        return absl::OkStatus();
      }

      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      std::vector<uint8_t> serialized_model;
      RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
          options, std::move(*graph), &serialized_model));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(serialized_model, builder));

      RETURN_IF_ERROR(SaveSerializedOpenCL(context, delegate_params, &options,
                                           serialization, serialized_model));
    }

    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenCL-based API.");
    return absl::OkStatus();
  }

  // Returns Ok only if serialized data is successfully found.
  absl::Status MaybeInitializeSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      std::unique_ptr<InferenceBuilder>* builder, cl::InferenceOptions* options,
      cl::InferenceEnvironmentOptions* env_options,
      cl::InferenceEnvironmentProperties* properties,
      Serialization* serialization) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);

    std::string model_data;
    auto model_data_status = data_key.GetData(context, &model_data);
    if (model_data_status == kTfLiteOk) {
      absl::Span<const uint8_t> model_span = absl::Span<const uint8_t>{
          reinterpret_cast<const uint8_t*>(model_data.data()),
          model_data.size()};
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(
          *env_options, &cl_environment_, properties));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(model_span, builder));
      TFLITE_LOG_PROD_ONCE(
          tflite::TFLITE_LOG_INFO,
          "Initialized OpenCL-based API from serialized data.");
      return absl::OkStatus();
    }

    return absl::NotFoundError("Serialization data not found");
  }

  // Returns Ok only if serialization happens successfully.
  absl::Status SaveSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      cl::InferenceOptions* options, Serialization* serialization,
      const std::vector<uint8_t>& serialized_model) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));

    // Save data.
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);
    auto save_status = data_key.SetData(
        context, reinterpret_cast<const char*>(serialized_model.data()),
        serialized_model.size());
    if (save_status != kTfLiteOk) {
      return absl::InvalidArgumentError("Failed to save serialized data");
    }
    return absl::OkStatus();
  }

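  // Creates the OpenGL environment and inference builder. The OpenGL path
  // requires Invoke() to run on the same thread as Prepare().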
  absl::Status InitializeOpenGlApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder) {
#ifndef CL_DELEGATE_NO_GL
    gl::InferenceEnvironmentOptions env_options;
    gl::InferenceEnvironmentProperties properties;
    RETURN_IF_ERROR(
        NewInferenceEnvironment(env_options, &gl_environment_, &properties));
    auto delegate_options = delegate_->options();
    gl::InferenceOptions options;
    options.usage = ToUsage(delegate_options.inference_preference);
    options.priority1 = ToPriority(delegate_options.inference_priority1);
    options.priority2 = ToPriority(delegate_options.inference_priority2);
    options.priority3 = ToPriority(delegate_options.inference_priority3);
    RETURN_IF_ERROR(gl_environment_->NewInferenceBuilder(std::move(*graph),
                                                         options, builder));
    enforce_same_thread_ = true;
    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenGL-based API.");
    return absl::OkStatus();
#else
    return absl::UnavailableError("OpenGL-based API disabled");
#endif
  }

  // The Delegate instance that's shared across all DelegateKernel instances.
  Delegate* const delegate_;  // doesn't own the memory.
  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#ifndef CL_DELEGATE_NO_GL
  std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#endif
  std::unique_ptr<InferenceRunner> runner_;
  std::vector<int64_t> input_indices_;
  std::vector<int64_t> output_indices_;
  // When quantized inference is enabled, maps the tensor index of each
  // originally quantized (8-bit) tensor to the float version added in
  // model_builder - and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;
  std::thread::id thread_id_prepare_;  // Thread id used for Prepare().
  bool enforce_same_thread_ = false;   // Enforce the same thread for Invoke().
};

inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) {
  return reinterpret_cast<DelegateKernel*>(node->user_data);
}

inline Delegate* GetDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

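// Replaces supported node subsets in the execution plan with GPU delegate
// kernels.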
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params =
            reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* gpu_delegate = GetDelegate(params->delegate);
        // Everything below should happen in the .prepare function call, but
        // TFLite currently forbids that.
        auto gpu_delegate_kernel =
            std::make_unique<DelegateKernel>(gpu_delegate);
        const auto status = gpu_delegate_kernel->Prepare(context, params);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Init: %s",
                             std::string(status.message()).c_str());
          return nullptr;
        }
        return gpu_delegate_kernel.release();
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {
        delete reinterpret_cast<DelegateKernel*>(buffer);
      },
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          TF_LITE_KERNEL_LOG(
              context,
              "TfLiteGpuDelegate Prepare: delegate is not initialized");
          return kTfLiteError;
        }
        auto* gpu_delegate_kernel = GetDelegateKernel(node);
        const auto status = gpu_delegate_kernel->GetRequiredTemporaries(
            context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        // TODO(akulik): tflite tensors are not allocated here either. It would
        // be good to set inputs and outputs only once here instead of setting
        // them every time in .invoke.
        return kTfLiteOk;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetDelegateKernel(node)->Invoke(context);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Invoke: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return kTfLiteOk;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteGpuDelegateV2",  // .custom_name
      1,                      // .version
  };

  auto* gpu_delegate = GetDelegate(delegate);
  absl::flat_hash_set<TfLiteBuiltinOperator> excluded_ops;
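  // When OpenCL is unavailable, the OpenGL fallback is used; exclude ops that
  // should not be delegated in that case.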
  if (!cl::OpenCLSupported()) {
    excluded_ops.insert(kTfLiteBuiltinSplit);
    excluded_ops.insert(kTfLiteBuiltinSplitV);
  }
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, gpu_delegate->IsQuantOpsAllowed(),
                      gpu_delegate->MaxDelegatedPartitions(), &excluded_ops);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kRegistration, ops_to_replace, delegate);
  TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Created %d GPU delegate kernels.",
                  gpu_delegate->num_delegate_kernels());
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace gpu
}  // namespace tflite

TfLiteDelegate* TfLiteGpuDelegateV2Create(
    const TfLiteGpuDelegateOptionsV2* options) {
  auto* gpu_delegate = new tflite::gpu::Delegate(options);
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                       "Created TensorFlow Lite delegate for GPU.");
  return gpu_delegate ? gpu_delegate->tflite_delegate() : nullptr;
}

void TfLiteGpuDelegateV2Delete(TfLiteDelegate* delegate) {
  delete tflite::gpu::GetDelegate(delegate);
}
599