xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/api.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/api.h"
17 
18 #include <variant>
19 
20 namespace tflite {
21 namespace gpu {
22 namespace {
23 
24 struct ObjectTypeGetter {
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter25   ObjectType operator()(std::monostate) const { return ObjectType::UNKNOWN; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter26   ObjectType operator()(OpenGlBuffer) const { return ObjectType::OPENGL_SSBO; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter27   ObjectType operator()(OpenGlTexture) const {
28     return ObjectType::OPENGL_TEXTURE;
29   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter30   ObjectType operator()(OpenClBuffer) const {
31     return ObjectType::OPENCL_BUFFER;
32   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter33   ObjectType operator()(OpenClTexture) const {
34     return ObjectType::OPENCL_TEXTURE;
35   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter36   ObjectType operator()(VulkanBuffer) const {
37     return ObjectType::VULKAN_BUFFER;
38   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter39   ObjectType operator()(VulkanTexture) const {
40     return ObjectType::VULKAN_TEXTURE;
41   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectTypeGetter42   ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
43 };
44 
45 struct ObjectValidityChecker {
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker46   bool operator()(std::monostate) const { return false; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker47   bool operator()(OpenGlBuffer obj) const { return obj.id != GL_INVALID_INDEX; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker48   bool operator()(OpenGlTexture obj) const {
49     return obj.id != GL_INVALID_INDEX && obj.format != GL_INVALID_ENUM;
50   }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker51   bool operator()(OpenClBuffer obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker52   bool operator()(OpenClTexture obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker53   bool operator()(VulkanBuffer obj) const { return obj.memory; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker54   bool operator()(VulkanTexture obj) const { return obj.memory; }
operator ()tflite::gpu::__anon7b9c542e0111::ObjectValidityChecker55   bool operator()(CpuMemory obj) const {
56     return obj.data != nullptr && obj.size_bytes > 0 &&
57            (data_type == DataType::UNKNOWN || data_type == DataType::BOOL ||
58             obj.size_bytes % SizeOf(data_type) == 0);
59   }
60   DataType data_type;
61 };
62 
63 }  // namespace
64 
IsValid(const ObjectDef & def)65 bool IsValid(const ObjectDef& def) {
66   return def.data_type != DataType::UNKNOWN &&
67          def.data_layout != DataLayout::UNKNOWN &&
68          def.object_type != ObjectType::UNKNOWN;
69 }
70 
GetType(const TensorObject & object)71 ObjectType GetType(const TensorObject& object) {
72   return std::visit(ObjectTypeGetter{}, object);
73 }
74 
IsValid(const TensorObjectDef & def)75 bool IsValid(const TensorObjectDef& def) { return IsValid(def.object_def); }
76 
IsValid(const TensorObjectDef & def,const TensorObject & object)77 bool IsValid(const TensorObjectDef& def, const TensorObject& object) {
78   return GetType(object) == def.object_def.object_type &&
79          std::visit(ObjectValidityChecker{def.object_def.data_type}, object);
80 }
81 
IsObjectPresent(ObjectType type,const TensorObject & obj)82 bool IsObjectPresent(ObjectType type, const TensorObject& obj) {
83   switch (type) {
84     case ObjectType::CPU_MEMORY:
85       return std::holds_alternative<CpuMemory>(obj);
86     case ObjectType::OPENGL_SSBO:
87       return std::holds_alternative<OpenGlBuffer>(obj);
88     case ObjectType::OPENGL_TEXTURE:
89       return std::holds_alternative<OpenGlTexture>(obj);
90     case ObjectType::OPENCL_BUFFER:
91       return std::holds_alternative<OpenClBuffer>(obj);
92     case ObjectType::OPENCL_TEXTURE:
93       return std::holds_alternative<OpenClTexture>(obj);
94     case ObjectType::VULKAN_BUFFER:
95       return std::holds_alternative<VulkanBuffer>(obj);
96     case ObjectType::VULKAN_TEXTURE:
97       return std::holds_alternative<VulkanTexture>(obj);
98     case ObjectType::UNKNOWN:
99       return false;
100   }
101 }
102 
IsObjectInitialized(const TensorObject & obj)103 bool IsObjectInitialized(const TensorObject& obj) {
104   return GetType(obj) != ObjectType::UNKNOWN;
105 }
106 
NumElements(const TensorObjectDef & def)107 uint32_t NumElements(const TensorObjectDef& def) {
108   const auto& d = def.dimensions;
109   switch (def.object_def.data_layout) {
110     case DataLayout::BHWC:
111       return d.product();
112     case DataLayout::HWDC4:
113     case DataLayout::HDWC4:
114     case DataLayout::DHWC4:
115       return d.b * d.h * d.w * AlignByN(d.c, 4);
116     case DataLayout::UNKNOWN:
117       return 0;
118   }
119   return 0;
120 }
121 
GetPosition(const InferenceOptions & options,InferencePriority p)122 int GetPosition(const InferenceOptions& options, InferencePriority p) {
123   if (options.priority1 == p) return 1;
124   if (options.priority2 == p) return 2;
125   if (options.priority3 == p) return 3;
126   return 4;  // least important
127 }
128 
GetRelativeImportance(const InferenceOptions & options,InferencePriority p1,InferencePriority p2)129 PriorityImportance GetRelativeImportance(const InferenceOptions& options,
130                                          InferencePriority p1,
131                                          InferencePriority p2) {
132   int p1_position = GetPosition(options, p1);
133   int p2_position = GetPosition(options, p2);
134   if (p1_position == p2_position) return PriorityImportance::UNKNOWN;
135   return p1_position < p2_position ? PriorityImportance::HIGHER
136                                    : PriorityImportance::LOWER;
137 }
138 
IsValid(const InferenceOptions & options)139 bool IsValid(const InferenceOptions& options) {
140   if (options.usage == InferenceUsage::UNKNOWN) {
141     return false;
142   }
143   if (options.priority1 == InferencePriority::UNKNOWN ||
144       options.priority2 == InferencePriority::UNKNOWN ||
145       options.priority3 == InferencePriority::UNKNOWN) {
146     return false;
147   }
148   if (options.priority1 == InferencePriority::AUTO) {
149     return false;
150   }
151   if (options.priority2 == InferencePriority::AUTO &&
152       options.priority3 != InferencePriority::AUTO) {
153     return false;
154   }
155   if (options.priority1 == options.priority2 ||
156       options.priority1 == options.priority3) {
157     return false;
158   }
159   if (options.priority2 == options.priority3 &&
160       options.priority2 != InferencePriority::AUTO) {
161     return false;
162   }
163   return true;
164 }
165 
// Implementation note: this resolution logic is currently shared by the GL and
// CL backends, although each backend may eventually need its own logic. The
// function is defined here purely for code reuse.
ResolveAutoPriority(InferenceOptions * options)169 void ResolveAutoPriority(InferenceOptions* options) {
170   // priority1 can not be AUTO as it would make options invalid.
171   if (options->priority2 == InferencePriority::AUTO) {
172     switch (options->priority1) {
173       case InferencePriority::MIN_LATENCY:
174         options->priority2 = InferencePriority::MIN_MEMORY_USAGE;
175         options->priority3 = InferencePriority::MAX_PRECISION;
176         return;
177       case InferencePriority::MIN_MEMORY_USAGE:
178         options->priority2 = InferencePriority::MAX_PRECISION;
179         options->priority3 = InferencePriority::MIN_LATENCY;
180         return;
181       case InferencePriority::MAX_PRECISION:
182         options->priority2 = InferencePriority::MIN_LATENCY;
183         options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
184         return;
185       case InferencePriority::UNKNOWN:
186       case InferencePriority::AUTO:
187         // Invalid and unreachable option.
188         return;
189     }
190   }
191 
192   if (options->priority3 == InferencePriority::AUTO) {
193     // Simply add missing priority
194     if (GetPosition(*options, InferencePriority::MIN_LATENCY) == 4) {
195       options->priority3 = InferencePriority::MIN_LATENCY;
196     } else if (GetPosition(*options, InferencePriority::MAX_PRECISION) == 4) {
197       options->priority3 = InferencePriority::MAX_PRECISION;
198     } else if (GetPosition(*options, InferencePriority::MIN_MEMORY_USAGE) ==
199                4) {
200       options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
201     }
202   }
203 }
204 
205 }  // namespace gpu
206 }  // namespace tflite
207