xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/data_type.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
17 
18 #include <stddef.h>
19 
20 #include <string>
21 
22 #include "absl/strings/str_cat.h"
23 
24 namespace tflite {
25 namespace gpu {
26 namespace {
// Builds a GLSL type name for a value with `vec_size` components.
// A single component yields `scalar_type` unchanged; any other count yields
// `vec_type` followed by the decimal count (e.g. ("float", "vec", 4) -> "vec4").
std::string ToGlslType(const std::string& scalar_type,
                       const std::string& vec_type, int vec_size) {
  if (vec_size != 1) {
    return vec_type + std::to_string(vec_size);
  }
  return scalar_type;
}
31 
GetGlslPrecisionModifier(DataType data_type)32 std::string GetGlslPrecisionModifier(DataType data_type) {
33   switch (data_type) {
34     case DataType::UINT8:
35     case DataType::INT8:
36       return "lowp ";
37     case DataType::FLOAT16:
38     case DataType::INT16:
39     case DataType::UINT16:
40       return "mediump ";
41     case DataType::FLOAT32:
42     case DataType::INT32:
43     case DataType::UINT32:
44       return "highp ";
45     case DataType::BOOL:
46       return "";
47     default:
48       return "";
49   }
50 }
51 }  // namespace
52 
SizeOf(DataType data_type)53 size_t SizeOf(DataType data_type) {
54   switch (data_type) {
55     case DataType::UINT8:
56     case DataType::INT8:
57     case DataType::BOOL:
58       return 1;
59     case DataType::FLOAT16:
60     case DataType::INT16:
61     case DataType::UINT16:
62       return 2;
63     case DataType::FLOAT32:
64     case DataType::INT32:
65     case DataType::UINT32:
66       return 4;
67     case DataType::FLOAT64:
68     case DataType::INT64:
69     case DataType::UINT64:
70       return 8;
71     case DataType::UNKNOWN:
72       return 0;
73   }
74   return 0;
75 }
76 
ToString(DataType data_type)77 std::string ToString(DataType data_type) {
78   switch (data_type) {
79     case DataType::FLOAT16:
80       return "float16";
81     case DataType::FLOAT32:
82       return "float32";
83     case DataType::FLOAT64:
84       return "float64";
85     case DataType::INT16:
86       return "int16";
87     case DataType::INT32:
88       return "int32";
89     case DataType::INT64:
90       return "int64";
91     case DataType::INT8:
92       return "int8";
93     case DataType::UINT16:
94       return "uint16";
95     case DataType::UINT32:
96       return "uint32";
97     case DataType::UINT64:
98       return "uint64";
99     case DataType::UINT8:
100       return "uint8";
101     case DataType::BOOL:
102       return "bool";
103     case DataType::UNKNOWN:
104       return "unknown";
105   }
106   return "undefined";
107 }
108 
ToCLDataType(DataType data_type,int vec_size)109 std::string ToCLDataType(DataType data_type, int vec_size) {
110   const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
111   switch (data_type) {
112     case DataType::FLOAT16:
113       return "half" + postfix;
114     case DataType::FLOAT32:
115       return "float" + postfix;
116     case DataType::FLOAT64:
117       return "double" + postfix;
118     case DataType::INT16:
119       return "short" + postfix;
120     case DataType::INT32:
121       return "int" + postfix;
122     case DataType::INT64:
123       return "long" + postfix;
124     case DataType::INT8:
125       return "char" + postfix;
126     case DataType::UINT16:
127       return "ushort" + postfix;
128     case DataType::UINT32:
129       return "uint" + postfix;
130     case DataType::UINT64:
131       return "ulong" + postfix;
132     case DataType::UINT8:
133       return "uchar" + postfix;
134     case DataType::BOOL:
135       return "bool" + postfix;
136     case DataType::UNKNOWN:
137       return "unknown";
138   }
139   return "undefined";
140 }
141 
ToMetalDataType(DataType data_type,int vec_size)142 std::string ToMetalDataType(DataType data_type, int vec_size) {
143   const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
144   switch (data_type) {
145     case DataType::FLOAT16:
146       return "half" + postfix;
147     case DataType::FLOAT32:
148       return "float" + postfix;
149     case DataType::FLOAT64:
150       return "double" + postfix;
151     case DataType::INT16:
152       return "short" + postfix;
153     case DataType::INT32:
154       return "int" + postfix;
155     case DataType::INT64:
156       return "long" + postfix;
157     case DataType::INT8:
158       return "char" + postfix;
159     case DataType::UINT16:
160       return "ushort" + postfix;
161     case DataType::UINT32:
162       return "uint" + postfix;
163     case DataType::UINT64:
164       return "ulong" + postfix;
165     case DataType::UINT8:
166       return "uchar" + postfix;
167     case DataType::BOOL:
168       return "bool" + postfix;
169     case DataType::UNKNOWN:
170       return "unknown";
171   }
172   return "undefined";
173 }
174 
ToMetalTextureType(DataType data_type)175 DataType ToMetalTextureType(DataType data_type) {
176   switch (data_type) {
177     case DataType::FLOAT32:
178     case DataType::FLOAT16:
179     case DataType::INT32:
180     case DataType::INT16:
181     case DataType::UINT32:
182     case DataType::UINT16:
183       return data_type;
184     case DataType::INT8:
185       return DataType::INT16;
186     case DataType::UINT8:
187     case DataType::BOOL:
188       return DataType::UINT16;
189     default:
190       return DataType::UNKNOWN;
191   }
192 }
193 
ToGlslShaderDataType(DataType data_type,int vec_size,bool add_precision,bool explicit_fp16)194 std::string ToGlslShaderDataType(DataType data_type, int vec_size,
195                                  bool add_precision, bool explicit_fp16) {
196   const std::string precision_modifier =
197       add_precision ? GetGlslPrecisionModifier(data_type) : "";
198   switch (data_type) {
199     case DataType::FLOAT16:
200       if (explicit_fp16) {
201         return ToGlslType("float16_t", "f16vec", vec_size);
202       } else {
203         return precision_modifier + ToGlslType("float", "vec", vec_size);
204       }
205     case DataType::FLOAT32:
206       return precision_modifier + ToGlslType("float", "vec", vec_size);
207     case DataType::FLOAT64:
208       return precision_modifier + ToGlslType("double", "dvec", vec_size);
209     case DataType::INT8:
210     case DataType::INT16:
211     case DataType::INT32:
212     case DataType::INT64:
213       return precision_modifier + ToGlslType("int", "ivec", vec_size);
214     case DataType::UINT8:
215     case DataType::UINT16:
216     case DataType::UINT32:
217     case DataType::UINT64:
218       return precision_modifier + ToGlslType("uint", "uvec", vec_size);
219     case DataType::BOOL:
220       return ToGlslType("bool", "bvec", vec_size);
221     case DataType::UNKNOWN:
222       return "unknown";
223   }
224   return "unknown";
225 }
226 
227 }  // namespace gpu
228 }  // namespace tflite
229