/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/api.h"

#include <variant>

namespace tflite {
namespace gpu {
namespace {

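// std::visit functor that maps each TensorObject alternative to the
// corresponding ObjectType.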
struct ObjectTypeGetter {
  ObjectType operator()(std::monostate) const { return ObjectType::UNKNOWN; }
  ObjectType operator()(OpenGlBuffer) const { return ObjectType::OPENGL_SSBO; }
  ObjectType operator()(OpenGlTexture) const {
    return ObjectType::OPENGL_TEXTURE;
  }
  ObjectType operator()(OpenClBuffer) const {
    return ObjectType::OPENCL_BUFFER;
  }
  ObjectType operator()(OpenClTexture) const {
    return ObjectType::OPENCL_TEXTURE;
  }
  ObjectType operator()(VulkanBuffer) const {
    return ObjectType::VULKAN_BUFFER;
  }
  ObjectType operator()(VulkanTexture) const {
    return ObjectType::VULKAN_TEXTURE;
  }
  ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
};

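// std::visit functor that reports whether a TensorObject alternative holds a
// usable handle and, for CpuMemory, a size consistent with data_type.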
struct ObjectValidityChecker {
  bool operator()(std::monostate) const { return false; }
  bool operator()(OpenGlBuffer obj) const {
    return obj.id != GL_INVALID_INDEX;
  }
  bool operator()(OpenGlTexture obj) const {
    return obj.id != GL_INVALID_INDEX && obj.format != GL_INVALID_ENUM;
  }
  bool operator()(OpenClBuffer obj) const { return obj.memobj; }
  bool operator()(OpenClTexture obj) const { return obj.memobj; }
  bool operator()(VulkanBuffer obj) const { return obj.memory; }
  bool operator()(VulkanTexture obj) const { return obj.memory; }
  bool operator()(CpuMemory obj) const {
    return obj.data != nullptr && obj.size_bytes > 0 &&
           (data_type == DataType::UNKNOWN || data_type == DataType::BOOL ||
            obj.size_bytes % SizeOf(data_type) == 0);
  }
  DataType data_type;
};

}  // namespace

bool IsValid(const ObjectDef& def) {
  return def.data_type != DataType::UNKNOWN &&
         def.data_layout != DataLayout::UNKNOWN &&
         def.object_type != ObjectType::UNKNOWN;
}

ObjectType GetType(const TensorObject& object) {
  return std::visit(ObjectTypeGetter{}, object);
}

bool IsValid(const TensorObjectDef& def) { return IsValid(def.object_def); }

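// A tensor object matches its def when its runtime type equals the declared
// object_type and it passes the per-backend validity check above.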
bool IsValid(const TensorObjectDef& def, const TensorObject& object) {
  return GetType(object) == def.object_def.object_type &&
         std::visit(ObjectValidityChecker{def.object_def.data_type}, object);
}

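// Returns true if the variant currently holds the alternative that
// corresponds to `type`.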
bool IsObjectPresent(ObjectType type, const TensorObject& obj) {
  switch (type) {
    case ObjectType::CPU_MEMORY:
      return std::holds_alternative<CpuMemory>(obj);
    case ObjectType::OPENGL_SSBO:
      return std::holds_alternative<OpenGlBuffer>(obj);
    case ObjectType::OPENGL_TEXTURE:
      return std::holds_alternative<OpenGlTexture>(obj);
    case ObjectType::OPENCL_BUFFER:
      return std::holds_alternative<OpenClBuffer>(obj);
    case ObjectType::OPENCL_TEXTURE:
      return std::holds_alternative<OpenClTexture>(obj);
    case ObjectType::VULKAN_BUFFER:
      return std::holds_alternative<VulkanBuffer>(obj);
    case ObjectType::VULKAN_TEXTURE:
      return std::holds_alternative<VulkanTexture>(obj);
    case ObjectType::UNKNOWN:
      return false;
  }
  return false;  // Unreachable for valid ObjectType values.
}

bool IsObjectInitialized(const TensorObject& obj) {
  return GetType(obj) != ObjectType::UNKNOWN;
}

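// Number of elements needed to store a tensor with the given dimensions in
// the given layout. The *C4 layouts pad the channel dimension up to a
// multiple of four.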
uint32_t NumElements(const TensorObjectDef& def) {
  const auto& d = def.dimensions;
  switch (def.object_def.data_layout) {
    case DataLayout::BHWC:
      return d.product();
    case DataLayout::HWDC4:
    case DataLayout::HDWC4:
    case DataLayout::DHWC4:
      return d.b * d.h * d.w * AlignByN(d.c, 4);
    case DataLayout::UNKNOWN:
      return 0;
  }
  return 0;
}
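// Returns the 1-based position of priority `p` within `options` (1 is the
// most important), or 4 when `p` is not among the configured priorities.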
int GetPosition(const InferenceOptions& options, InferencePriority p) {
  if (options.priority1 == p) return 1;
  if (options.priority2 == p) return 2;
  if (options.priority3 == p) return 3;
  return 4;  // least important
}

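// HIGHER if p1 outranks p2 within `options`, LOWER if p2 outranks p1, and
// UNKNOWN if they occupy the same position.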
PriorityImportance GetRelativeImportance(const InferenceOptions& options,
                                         InferencePriority p1,
                                         InferencePriority p2) {
  int p1_position = GetPosition(options, p1);
  int p2_position = GetPosition(options, p2);
  if (p1_position == p2_position) return PriorityImportance::UNKNOWN;
  return p1_position < p2_position ? PriorityImportance::HIGHER
                                   : PriorityImportance::LOWER;
}

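// Options are valid when the usage is known, all three priorities are known,
// priority1 is not AUTO, priority2 is AUTO only if priority3 is also AUTO,
// and no non-AUTO priority value is repeated.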
bool IsValid(const InferenceOptions& options) {
  if (options.usage == InferenceUsage::UNKNOWN) {
    return false;
  }
  if (options.priority1 == InferencePriority::UNKNOWN ||
      options.priority2 == InferencePriority::UNKNOWN ||
      options.priority3 == InferencePriority::UNKNOWN) {
    return false;
  }
  if (options.priority1 == InferencePriority::AUTO) {
    return false;
  }
  if (options.priority2 == InferencePriority::AUTO &&
      options.priority3 != InferencePriority::AUTO) {
    return false;
  }
  if (options.priority1 == options.priority2 ||
      options.priority1 == options.priority3) {
    return false;
  }
  if (options.priority2 == options.priority3 &&
      options.priority2 != InferencePriority::AUTO) {
    return false;
  }
  return true;
}

// Implementation note: this resolution logic is shared between the GL and CL
// backends, but each backend may eventually use its own logic. The function is
// defined here purely for code reuse.
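// Example: {MAX_PRECISION, AUTO, AUTO} resolves to
// {MAX_PRECISION, MIN_LATENCY, MIN_MEMORY_USAGE}.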
void ResolveAutoPriority(InferenceOptions* options) {
  // priority1 cannot be AUTO, as that would make the options invalid.
  if (options->priority2 == InferencePriority::AUTO) {
    switch (options->priority1) {
      case InferencePriority::MIN_LATENCY:
        options->priority2 = InferencePriority::MIN_MEMORY_USAGE;
        options->priority3 = InferencePriority::MAX_PRECISION;
        return;
      case InferencePriority::MIN_MEMORY_USAGE:
        options->priority2 = InferencePriority::MAX_PRECISION;
        options->priority3 = InferencePriority::MIN_LATENCY;
        return;
      case InferencePriority::MAX_PRECISION:
        options->priority2 = InferencePriority::MIN_LATENCY;
        options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
        return;
      case InferencePriority::UNKNOWN:
      case InferencePriority::AUTO:
        // Invalid and unreachable option.
        return;
    }
  }

  if (options->priority3 == InferencePriority::AUTO) {
    // Simply add the missing priority.
    if (GetPosition(*options, InferencePriority::MIN_LATENCY) == 4) {
      options->priority3 = InferencePriority::MIN_LATENCY;
    } else if (GetPosition(*options, InferencePriority::MAX_PRECISION) == 4) {
      options->priority3 = InferencePriority::MAX_PRECISION;
    } else if (GetPosition(*options, InferencePriority::MIN_MEMORY_USAGE) ==
               4) {
      options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
    }
  }
}

}  // namespace gpu
}  // namespace tflite