1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
17
18 #include <stddef.h>
19
20 #include <string>
21
22 #include "absl/strings/str_cat.h"
23
24 namespace tflite {
25 namespace gpu {
26 namespace {
ToGlslType(const std::string & scalar_type,const std::string & vec_type,int vec_size)27 std::string ToGlslType(const std::string& scalar_type,
28 const std::string& vec_type, int vec_size) {
29 return vec_size == 1 ? scalar_type : absl::StrCat(vec_type, vec_size);
30 }
31
GetGlslPrecisionModifier(DataType data_type)32 std::string GetGlslPrecisionModifier(DataType data_type) {
33 switch (data_type) {
34 case DataType::UINT8:
35 case DataType::INT8:
36 return "lowp ";
37 case DataType::FLOAT16:
38 case DataType::INT16:
39 case DataType::UINT16:
40 return "mediump ";
41 case DataType::FLOAT32:
42 case DataType::INT32:
43 case DataType::UINT32:
44 return "highp ";
45 case DataType::BOOL:
46 return "";
47 default:
48 return "";
49 }
50 }
51 } // namespace
52
SizeOf(DataType data_type)53 size_t SizeOf(DataType data_type) {
54 switch (data_type) {
55 case DataType::UINT8:
56 case DataType::INT8:
57 case DataType::BOOL:
58 return 1;
59 case DataType::FLOAT16:
60 case DataType::INT16:
61 case DataType::UINT16:
62 return 2;
63 case DataType::FLOAT32:
64 case DataType::INT32:
65 case DataType::UINT32:
66 return 4;
67 case DataType::FLOAT64:
68 case DataType::INT64:
69 case DataType::UINT64:
70 return 8;
71 case DataType::UNKNOWN:
72 return 0;
73 }
74 return 0;
75 }
76
ToString(DataType data_type)77 std::string ToString(DataType data_type) {
78 switch (data_type) {
79 case DataType::FLOAT16:
80 return "float16";
81 case DataType::FLOAT32:
82 return "float32";
83 case DataType::FLOAT64:
84 return "float64";
85 case DataType::INT16:
86 return "int16";
87 case DataType::INT32:
88 return "int32";
89 case DataType::INT64:
90 return "int64";
91 case DataType::INT8:
92 return "int8";
93 case DataType::UINT16:
94 return "uint16";
95 case DataType::UINT32:
96 return "uint32";
97 case DataType::UINT64:
98 return "uint64";
99 case DataType::UINT8:
100 return "uint8";
101 case DataType::BOOL:
102 return "bool";
103 case DataType::UNKNOWN:
104 return "unknown";
105 }
106 return "undefined";
107 }
108
ToCLDataType(DataType data_type,int vec_size)109 std::string ToCLDataType(DataType data_type, int vec_size) {
110 const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
111 switch (data_type) {
112 case DataType::FLOAT16:
113 return "half" + postfix;
114 case DataType::FLOAT32:
115 return "float" + postfix;
116 case DataType::FLOAT64:
117 return "double" + postfix;
118 case DataType::INT16:
119 return "short" + postfix;
120 case DataType::INT32:
121 return "int" + postfix;
122 case DataType::INT64:
123 return "long" + postfix;
124 case DataType::INT8:
125 return "char" + postfix;
126 case DataType::UINT16:
127 return "ushort" + postfix;
128 case DataType::UINT32:
129 return "uint" + postfix;
130 case DataType::UINT64:
131 return "ulong" + postfix;
132 case DataType::UINT8:
133 return "uchar" + postfix;
134 case DataType::BOOL:
135 return "bool" + postfix;
136 case DataType::UNKNOWN:
137 return "unknown";
138 }
139 return "undefined";
140 }
141
ToMetalDataType(DataType data_type,int vec_size)142 std::string ToMetalDataType(DataType data_type, int vec_size) {
143 const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
144 switch (data_type) {
145 case DataType::FLOAT16:
146 return "half" + postfix;
147 case DataType::FLOAT32:
148 return "float" + postfix;
149 case DataType::FLOAT64:
150 return "double" + postfix;
151 case DataType::INT16:
152 return "short" + postfix;
153 case DataType::INT32:
154 return "int" + postfix;
155 case DataType::INT64:
156 return "long" + postfix;
157 case DataType::INT8:
158 return "char" + postfix;
159 case DataType::UINT16:
160 return "ushort" + postfix;
161 case DataType::UINT32:
162 return "uint" + postfix;
163 case DataType::UINT64:
164 return "ulong" + postfix;
165 case DataType::UINT8:
166 return "uchar" + postfix;
167 case DataType::BOOL:
168 return "bool" + postfix;
169 case DataType::UNKNOWN:
170 return "unknown";
171 }
172 return "undefined";
173 }
174
ToMetalTextureType(DataType data_type)175 DataType ToMetalTextureType(DataType data_type) {
176 switch (data_type) {
177 case DataType::FLOAT32:
178 case DataType::FLOAT16:
179 case DataType::INT32:
180 case DataType::INT16:
181 case DataType::UINT32:
182 case DataType::UINT16:
183 return data_type;
184 case DataType::INT8:
185 return DataType::INT16;
186 case DataType::UINT8:
187 case DataType::BOOL:
188 return DataType::UINT16;
189 default:
190 return DataType::UNKNOWN;
191 }
192 }
193
ToGlslShaderDataType(DataType data_type,int vec_size,bool add_precision,bool explicit_fp16)194 std::string ToGlslShaderDataType(DataType data_type, int vec_size,
195 bool add_precision, bool explicit_fp16) {
196 const std::string precision_modifier =
197 add_precision ? GetGlslPrecisionModifier(data_type) : "";
198 switch (data_type) {
199 case DataType::FLOAT16:
200 if (explicit_fp16) {
201 return ToGlslType("float16_t", "f16vec", vec_size);
202 } else {
203 return precision_modifier + ToGlslType("float", "vec", vec_size);
204 }
205 case DataType::FLOAT32:
206 return precision_modifier + ToGlslType("float", "vec", vec_size);
207 case DataType::FLOAT64:
208 return precision_modifier + ToGlslType("double", "dvec", vec_size);
209 case DataType::INT8:
210 case DataType::INT16:
211 case DataType::INT32:
212 case DataType::INT64:
213 return precision_modifier + ToGlslType("int", "ivec", vec_size);
214 case DataType::UINT8:
215 case DataType::UINT16:
216 case DataType::UINT32:
217 case DataType::UINT64:
218 return precision_modifier + ToGlslType("uint", "uvec", vec_size);
219 case DataType::BOOL:
220 return ToGlslType("bool", "bvec", vec_size);
221 case DataType::UNKNOWN:
222 return "unknown";
223 }
224 return "unknown";
225 }
226
227 } // namespace gpu
228 } // namespace tflite
229