1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 /**
10 * @file
11 *
12 * Forked from
13 * https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
14 *
15 * See file comment in ../ScalarType.h.
16 *
17 * This file contains all of the non-critical parts of the original ScalarType.h
18 * that are not required for the core ExecuTorch runtime, but may be helpful for
19 * code that uses ScalarType.
20 */
21
22 #pragma once
23
24 #include <array>
25 #include <cinttypes>
26 #include <cstdint>
27 #include <limits>
28 #include <type_traits>
29
30 #include <executorch/runtime/platform/assert.h>
31
#ifdef USE_ATEN_LIB
// Note that a lot of the macros/functions defined in this ScalarTypeUtil.h file
// are also defined in c10/core/ScalarType.h, which is included via
// kernel_types.h when building in ATen mode. They tend to use different names
// and a different namespace, but if there are conflicts they should be resolved
// here.
//
// ATen mode: ET_FORALL_SCALAR_TYPES is mapped onto ATen's
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, and
// executorch::aten::ScalarType is an alias of at::ScalarType. Only ScalarType
// is aliased into executorch::aten in this branch.
#define ET_FORALL_SCALAR_TYPES AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#include <c10/core/ScalarType.h>
namespace executorch {
namespace aten {
using ScalarType = at::ScalarType;
} // namespace aten
} // namespace executorch
#else // !USE_ATEN_LIB
// Portable (non-ATen) mode: ScalarType and string_view come from ExecuTorch's
// own portable_type headers and are re-exported under executorch::aten.
#include <executorch/runtime/core/portable_type/scalar_type.h>
#include <executorch/runtime/core/portable_type/string_view.h>
namespace executorch {
namespace aten {
using ScalarType = torch::executor::ScalarType;
using string_view = torch::executor::string_view;
} // namespace aten
} // namespace executorch
#endif // USE_ATEN_LIB
// DEPRECATED: The exec_aten:: namespace is deprecated. Use executorch::aten::
// instead. It remains as a namespace alias so existing code keeps compiling.
namespace exec_aten = ::executorch::aten;
58
59 namespace executorch {
60 namespace runtime {
61
62 #if !defined(USE_ATEN_LIB)
63 // Util to figure out if the scalar type if one of the
64 // supported floating point types.
65 // In aten mode, aten lib already has these utils as part of
66 // its vec_base.h
67 template <typename T>
68 struct is_floating_point
69 : std::integral_constant<
70 bool,
71 std::is_floating_point<T>::value ||
72 std::is_same_v<T, torch::executor::Half> ||
73 std::is_same_v<T, torch::executor::BFloat16>> {};
74
75 // Util to figure out if the scalar type is one of the
76 // reduced precision floating point types.
77 template <typename T>
78 struct is_reduced_floating_point
79 : std::integral_constant<
80 bool,
81 std::is_same_v<T, torch::executor::Half> ||
82 std::is_same_v<T, torch::executor::BFloat16>> {};
83
84 template <typename T>
85 constexpr bool is_reduced_floating_point_v =
86 is_reduced_floating_point<T>::value;
87 #endif
88
89 /// Maps ScalarTypes to C++ types.
90 template <::executorch::aten::ScalarType N>
91 struct ScalarTypeToCppType;
92
93 #define SPECIALIZE_ScalarTypeToCppType(cpp_type, scalar_type) \
94 template <> \
95 struct ScalarTypeToCppType<::executorch::aten::ScalarType::scalar_type> { \
96 using type = cpp_type; \
97 };
98
99 ET_FORALL_SCALAR_TYPES(SPECIALIZE_ScalarTypeToCppType)
100
101 #undef SPECIALIZE_ScalarTypeToCppType
102
103 /// Maps C++ types to ScalarTypes.
104 template <typename T>
105 struct CppTypeToScalarType;
106
107 #define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
108 template <> \
109 struct CppTypeToScalarType<cpp_type> \
110 : std::integral_constant< \
111 ::executorch::aten::ScalarType, \
112 ::executorch::aten::ScalarType::scalar_type> {};
113
ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)114 ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
115
116 #undef SPECIALIZE_CppTypeToScalarType
117
//
// Macros that iterate across different subsets of ScalarTypes.
//
// See ET_FORALL_SCALAR_TYPES in ScalarType.h to iterate across all ScalarType
// names and types.
//
// For all of these macros, the final `_` parameter is the name of another macro
// that takes two parameters: the name of a C type, and the name of the
// corresponding ScalarType enumerator. The `_WITH` / `_WITH2` variants instead
// pass one or two extra leading arguments (`ANOTHER_INPUT...`) through to `_`,
// so for those `_` takes three or four parameters.
//
// Note that these macros should use fully-qualified namespaces (starting with
// `::`) to ensure that they can be called safely in any arbitrary namespace.
//

// In this context, "INT" means integer C types, which is why the quantized
// integer types (and Bool) are not included.
#define ET_FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int32_t, Int)              \
  _(int64_t, Long)

// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given
// function. `_` is invoked as `_(ANOTHER_INPUT, ctype, Name)`.
#define ET_FORALL_INT_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, uint8_t, Byte)                  \
  _(ANOTHER_INPUT, int8_t, Char)                   \
  _(ANOTHER_INPUT, int16_t, Short)                 \
  _(ANOTHER_INPUT, int32_t, Int)                   \
  _(ANOTHER_INPUT, int64_t, Long)

// Like ET_FORALL_INT_TYPES_WITH, but forwards two extra leading arguments.
#define ET_FORALL_INT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte)                   \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char)                    \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short)                  \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int)                    \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long)

// The INT types plus one extra ScalarType, given as an unqualified enumerator
// name `SCALARTYPE`. Its C type is looked up through ScalarTypeToCppType.
#define ET_FORALL_INT_TYPES_AND(SCALARTYPE, _)             \
  _(uint8_t, Byte)                                         \
  _(int8_t, Char)                                          \
  _(int16_t, Short)                                        \
  _(int32_t, Int)                                          \
  _(int64_t, Long)                                         \
  _(::executorch::runtime::ScalarTypeToCppType<           \
        ::executorch::aten::ScalarType::SCALARTYPE>::type, \
    SCALARTYPE)
166
// In this context, "FLOAT" means the built-in C floating point types (float
// and double), which is why Half and BFloat16 are not included. Use the
// `_AND` / `_AND2` variants or the FLOATH / FLOATHBF16 aliases to add them.
#define ET_FORALL_FLOAT_TYPES(_) \
  _(float, Float)                \
  _(double, Double)

// Float and Double plus one extra ScalarType, given as an unqualified
// enumerator name `SCALARTYPE`. Its C type comes from ScalarTypeToCppType.
#define ET_FORALL_FLOAT_TYPES_AND(SCALARTYPE, _)           \
  _(float, Float)                                          \
  _(double, Double)                                        \
  _(::executorch::runtime::ScalarTypeToCppType<           \
        ::executorch::aten::ScalarType::SCALARTYPE>::type, \
    SCALARTYPE)

// Like ET_FORALL_FLOAT_TYPES_AND, but with two extra ScalarTypes.
#define ET_FORALL_FLOAT_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(float, Float)                                               \
  _(double, Double)                                             \
  _(::executorch::runtime::ScalarTypeToCppType<                \
        ::executorch::aten::ScalarType::SCALARTYPE1>::type,     \
    SCALARTYPE1)                                                \
  _(::executorch::runtime::ScalarTypeToCppType<                \
        ::executorch::aten::ScalarType::SCALARTYPE2>::type,     \
    SCALARTYPE2)

// Float, Double and Half.
#define ET_FORALL_FLOATH_TYPES(_) ET_FORALL_FLOAT_TYPES_AND(Half, _)

// Float, Double, Half and BFloat16.
#define ET_FORALL_FLOATHBF16_TYPES(_) \
  ET_FORALL_FLOAT_TYPES_AND2(Half, BFloat16, _)

// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given
// function. Not to be confused with another scalar type as in
// `ET_FORALL_FLOAT_TYPES_AND`.
#define ET_FORALL_FLOAT_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, float, Float)                     \
  _(ANOTHER_INPUT, double, Double)

// Like ET_FORALL_FLOAT_TYPES_WITH, but forwards two extra leading arguments.
#define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float)                      \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)

// Like ET_FORALL_FLOAT_TYPES_WITH2, but also covers Half and BFloat16.
// NOTE(review): these entries name ::executorch::aten::Half / BFloat16
// directly. This header only aliases ScalarType (and, in portable mode,
// string_view) into executorch::aten, so those names must be visible where the
// macro is expanded.
#define ET_FORALL_FLOATHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float)                           \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)                         \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half)         \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16)
211
// In this context, "REAL" means integer/float C types, which is why BFloat16
// and Half (as well as Bool) are not included.
#define ET_FORALL_REAL_TYPES(_) \
  _(uint8_t, Byte)              \
  _(int8_t, Char)               \
  _(int16_t, Short)             \
  _(int32_t, Int)               \
  _(int64_t, Long)              \
  _(float, Float)               \
  _(double, Double)

// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given
// function. Not to be confused with another scalar type as in
// `ET_FORALL_REAL_TYPES_AND`.
#define ET_FORALL_REAL_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, uint8_t, Byte)                   \
  _(ANOTHER_INPUT, int8_t, Char)                    \
  _(ANOTHER_INPUT, int16_t, Short)                  \
  _(ANOTHER_INPUT, int32_t, Int)                    \
  _(ANOTHER_INPUT, int64_t, Long)                   \
  _(ANOTHER_INPUT, float, Float)                    \
  _(ANOTHER_INPUT, double, Double)

// Like ET_FORALL_REAL_TYPES_WITH, but forwards two extra leading arguments.
#define ET_FORALL_REAL_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte)                    \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char)                     \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short)                   \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int)                     \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long)                    \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float)                     \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)

// Like ET_FORALL_REAL_TYPES_WITH2, but also covers Half and BFloat16.
// NOTE(review): ::executorch::aten::Half / BFloat16 are not declared by this
// header and must be visible where the macro is expanded.
#define ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte)                         \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char)                          \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short)                        \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int)                          \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long)                         \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float)                          \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)                        \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half)        \
  _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16)
254
// For macros that take `SCALARTYPEn` parameters, those parameters should be
// an unquoted/unqualified enumerator name like `Int` or `Float`. The matching
// C type is obtained through ScalarTypeToCppType.

// The REAL types plus one extra ScalarType.
#define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _)            \
  _(uint8_t, Byte)                                         \
  _(int8_t, Char)                                          \
  _(int16_t, Short)                                        \
  _(int32_t, Int)                                          \
  _(int64_t, Long)                                         \
  _(float, Float)                                          \
  _(double, Double)                                        \
  _(::executorch::runtime::ScalarTypeToCppType<           \
        ::executorch::aten::ScalarType::SCALARTYPE>::type, \
    SCALARTYPE)

// The REAL types plus two extra ScalarTypes.
#define ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(uint8_t, Byte)                                             \
  _(int8_t, Char)                                              \
  _(int16_t, Short)                                            \
  _(int32_t, Int)                                              \
  _(int64_t, Long)                                             \
  _(float, Float)                                              \
  _(double, Double)                                            \
  _(::executorch::runtime::ScalarTypeToCppType<               \
        ::executorch::aten::ScalarType::SCALARTYPE1>::type,    \
    SCALARTYPE1)                                               \
  _(::executorch::runtime::ScalarTypeToCppType<               \
        ::executorch::aten::ScalarType::SCALARTYPE2>::type,    \
    SCALARTYPE2)

// REAL types + Half.
#define ET_FORALL_REALH_TYPES(_) ET_FORALL_REAL_TYPES_AND(Half, _)

// REAL types + Half + BFloat16.
#define ET_FORALL_REALHBF16_TYPES(_) \
  ET_FORALL_REAL_TYPES_AND2(Half, BFloat16, _)

// REAL types + Bool + Half + BFloat16. ET_FORALL_REAL_TYPES_AND3 is defined
// further down; that is fine because macros are only expanded at use sites.
#define ET_FORALL_REALHBBF16_TYPES(_) \
  ET_FORALL_REAL_TYPES_AND3(Bool, Half, BFloat16, _)
291
// The REAL types plus one extra ScalarType, forwarding `ANOTHER_INPUT` as the
// first argument of every `_` invocation: `_(ANOTHER_INPUT, ctype, Name)`.
// Note the parameter order: SCALARTYPE comes before ANOTHER_INPUT.
#define ET_FORALL_REAL_TYPES_AND_WITH(SCALARTYPE, ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, uint8_t, Byte)                                   \
  _(ANOTHER_INPUT, int8_t, Char)                                    \
  _(ANOTHER_INPUT, int16_t, Short)                                  \
  _(ANOTHER_INPUT, int32_t, Int)                                    \
  _(ANOTHER_INPUT, int64_t, Long)                                   \
  _(ANOTHER_INPUT, float, Float)                                    \
  _(ANOTHER_INPUT, double, Double)                                  \
  _(ANOTHER_INPUT,                                                  \
    ::executorch::runtime::ScalarTypeToCppType<                     \
        ::executorch::aten::ScalarType::SCALARTYPE>::type,          \
    SCALARTYPE)
304
// ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) is defined once,
// next to ET_FORALL_REAL_TYPES_AND earlier in this file. It is deliberately
// not repeated here: a second copy is redundant, and it becomes a conflicting
// macro redefinition as soon as the two copies drift apart.
319
// The REAL types plus three extra ScalarTypes (unqualified enumerator names).
#define ET_FORALL_REAL_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                          \
  _(int8_t, Char)                                                           \
  _(int16_t, Short)                                                         \
  _(int32_t, Int)                                                           \
  _(int64_t, Long)                                                          \
  _(float, Float)                                                           \
  _(double, Double)                                                         \
  _(::executorch::runtime::ScalarTypeToCppType<                            \
        ::executorch::aten::ScalarType::SCALARTYPE1>::type,                 \
    SCALARTYPE1)                                                            \
  _(::executorch::runtime::ScalarTypeToCppType<                            \
        ::executorch::aten::ScalarType::SCALARTYPE2>::type,                 \
    SCALARTYPE2)                                                            \
  _(::executorch::runtime::ScalarTypeToCppType<                            \
        ::executorch::aten::ScalarType::SCALARTYPE3>::type,                 \
    SCALARTYPE3)
337
// The quantized integer types, paired with the portable runtime's C types for
// them (::torch::executor::qint8 etc.).
#define ET_FORALL_QINT_TYPES(_)               \
  _(::torch::executor::qint8, QInt8)          \
  _(::torch::executor::quint8, QUInt8)        \
  _(::torch::executor::qint32, QInt32)        \
  _(::torch::executor::quint4x2, QUInt4x2)    \
  _(::torch::executor::quint2x4, QUInt2x4)

// In this context, "COMPLEX" means complex types based on primitive C types,
// which is why ComplexHalf is not included.
#define ET_FORALL_COMPLEX_TYPES(_)                     \
  _(::torch::executor::complex<float>, ComplexFloat) \
  _(::torch::executor::complex<double>, ComplexDouble)
350
351 //
352 // Utility functions to retrieve metadata for a given ScalarType
353 //
354
355 /**
356 * Returns true if the parameter is one of the values covered by
357 * ET_FORALL_SCALAR_TYPES.
358 */
359 inline bool isValid(::executorch::aten::ScalarType type) {
360 return static_cast<int8_t>(type) >= 0 &&
361 type < ::executorch::aten::ScalarType::NumOptions &&
362 type != ::executorch::aten::ScalarType::Undefined;
363 }
364
365 /**
366 * Returns the name of a ScalarType as a C string.
367 *
368 * @param[in] t The type to get the name of.
369 * @return The name of the type, or "UNKNOWN_SCALAR" if the type is not known.
370 */
toString(::executorch::aten::ScalarType t)371 inline const char* toString(::executorch::aten::ScalarType t) {
372 #define DEFINE_CASE(_, name) \
373 case ::executorch::aten::ScalarType::name: \
374 return #name;
375
376 switch (t) {
377 ET_FORALL_SCALAR_TYPES(DEFINE_CASE)
378 case ::executorch::aten::ScalarType::Undefined:
379 return "Undefined";
380 default:
381 return "UNKNOWN_SCALAR";
382 }
383 #undef DEFINE_CASE
384 }
385
386 /**
387 * Returns the size in bytes of the C type associated with the ScalarType.
388 *
389 * Calls ET_CHECK_MSG() if the type is unknown or is ScalarType::Undefined.
390 *
391 * @param[in] t The type to get the underlying C type size of.
392 * @return The size of the associated C type in bytes.
393 */
elementSize(::executorch::aten::ScalarType t)394 inline size_t elementSize(::executorch::aten::ScalarType t) {
395 #define CASE_ELEMENTSIZE_CASE(ctype, name) \
396 case ::executorch::aten::ScalarType::name: \
397 return sizeof(ctype);
398
399 switch (t) {
400 ET_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE)
401 default:
402 ET_CHECK_MSG(false, "Unknown ScalarType %" PRId8, static_cast<int8_t>(t));
403 }
404 #undef CASE_ELEMENTSIZE_CASE
405 }
406
isIntegralType(::executorch::aten::ScalarType t,bool includeBool)407 inline constexpr bool isIntegralType(
408 ::executorch::aten::ScalarType t,
409 bool includeBool) {
410 return (includeBool && t == ::executorch::aten::ScalarType::Bool) ||
411 (t == ::executorch::aten::ScalarType::Byte ||
412 t == ::executorch::aten::ScalarType::Char ||
413 t == ::executorch::aten::ScalarType::Int ||
414 t == ::executorch::aten::ScalarType::Long ||
415 t == ::executorch::aten::ScalarType::Short);
416 }
417
418 template <typename T, bool includeBool>
419 struct is_integral_type
420 : public std::integral_constant<
421 bool,
422 isIntegralType(CppTypeToScalarType<T>::value, includeBool)> {};
423
isFloatingType(::executorch::aten::ScalarType t)424 inline constexpr bool isFloatingType(::executorch::aten::ScalarType t) {
425 return (
426 t == ::executorch::aten::ScalarType::Double ||
427 t == ::executorch::aten::ScalarType::Float ||
428 t == ::executorch::aten::ScalarType::Half ||
429 t == ::executorch::aten::ScalarType::BFloat16);
430 }
431
isRealType(::executorch::aten::ScalarType t)432 inline bool isRealType(::executorch::aten::ScalarType t) {
433 return (
434 t == ::executorch::aten::ScalarType::Byte ||
435 t == ::executorch::aten::ScalarType::Char ||
436 t == ::executorch::aten::ScalarType::Short ||
437 t == ::executorch::aten::ScalarType::Int ||
438 t == ::executorch::aten::ScalarType::Long ||
439 t == ::executorch::aten::ScalarType::Float ||
440 t == ::executorch::aten::ScalarType::Double);
441 }
442
isRealHType(::executorch::aten::ScalarType t)443 inline bool isRealHType(::executorch::aten::ScalarType t) {
444 return (
445 t == ::executorch::aten::ScalarType::Byte ||
446 t == ::executorch::aten::ScalarType::Char ||
447 t == ::executorch::aten::ScalarType::Short ||
448 t == ::executorch::aten::ScalarType::Int ||
449 t == ::executorch::aten::ScalarType::Long ||
450 t == ::executorch::aten::ScalarType::Float ||
451 t == ::executorch::aten::ScalarType::Double ||
452 t == ::executorch::aten::ScalarType::Half);
453 }
454
isRealHBType(::executorch::aten::ScalarType t)455 inline bool isRealHBType(::executorch::aten::ScalarType t) {
456 return (isRealHType(t) || t == ::executorch::aten::ScalarType::Bool);
457 }
458
isRealHBF16Type(::executorch::aten::ScalarType t)459 inline bool isRealHBF16Type(::executorch::aten::ScalarType t) {
460 return (isRealHType(t) || t == ::executorch::aten::ScalarType::BFloat16);
461 }
462
isRealHBBF16Type(::executorch::aten::ScalarType t)463 inline bool isRealHBBF16Type(::executorch::aten::ScalarType t) {
464 return (isRealHBType(t) || t == ::executorch::aten::ScalarType::BFloat16);
465 }
466
isComplexType(::executorch::aten::ScalarType t)467 inline constexpr bool isComplexType(::executorch::aten::ScalarType t) {
468 return (
469 t == ::executorch::aten::ScalarType::ComplexHalf ||
470 t == ::executorch::aten::ScalarType::ComplexFloat ||
471 t == ::executorch::aten::ScalarType::ComplexDouble);
472 }
473
474 template <typename T>
475 struct is_complex_type : std::integral_constant<
476 bool,
477 isComplexType(CppTypeToScalarType<T>::value)> {};
478
isBitsType(::executorch::aten::ScalarType t)479 constexpr bool isBitsType(::executorch::aten::ScalarType t) {
480 return t == ::executorch::aten::ScalarType::Bits1x8 ||
481 t == ::executorch::aten::ScalarType::Bits2x4 ||
482 t == ::executorch::aten::ScalarType::Bits4x2 ||
483 t == ::executorch::aten::ScalarType::Bits8 ||
484 t == ::executorch::aten::ScalarType::Bits16;
485 }
486
487 template <typename T>
488 struct is_bits_type
489 : std::integral_constant<bool, isBitsType(CppTypeToScalarType<T>::value)> {
490 };
491
isQIntType(::executorch::aten::ScalarType t)492 constexpr bool isQIntType(::executorch::aten::ScalarType t) {
493 // Don't forget to extend this when adding new QInt types
494 return t == ::executorch::aten::ScalarType::QInt8 ||
495 t == ::executorch::aten::ScalarType::QUInt8 ||
496 t == ::executorch::aten::ScalarType::QInt32 ||
497 t == ::executorch::aten::ScalarType::QUInt4x2 ||
498 t == ::executorch::aten::ScalarType::QUInt2x4;
499 }
500
501 template <typename T>
502 struct is_qint_type
503 : std::integral_constant<bool, isQIntType(CppTypeToScalarType<T>::value)> {
504 };
505
isFloat8Type(::executorch::aten::ScalarType t)506 constexpr bool isFloat8Type(::executorch::aten::ScalarType t) {
507 // Don't forget to extend this when adding new QInt types
508 return t == ::executorch::aten::ScalarType::Float8_e5m2 ||
509 t == ::executorch::aten::ScalarType::Float8_e4m3fn ||
510 t == ::executorch::aten::ScalarType::Float8_e5m2fnuz ||
511 t == ::executorch::aten::ScalarType::Float8_e4m3fnuz;
512 }
513
514 template <typename T>
515 struct is_float8_type
516 : std::
517 integral_constant<bool, isFloat8Type(CppTypeToScalarType<T>::value)> {
518 };
519
isBarebonesUnsignedType(::executorch::aten::ScalarType t)520 constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) {
521 // Don't forget to extend this when adding new QInt types
522 return t == ::executorch::aten::ScalarType::UInt16 ||
523 t == ::executorch::aten::ScalarType::UInt32 ||
524 t == ::executorch::aten::ScalarType::UInt64;
525 }
526
527 template <typename T>
528 struct is_barebones_unsigned_type
529 : std::integral_constant<
530 bool,
531 isBarebonesUnsignedType(CppTypeToScalarType<T>::value)> {};
532
toQIntType(::executorch::aten::ScalarType t)533 inline ::executorch::aten::ScalarType toQIntType(
534 ::executorch::aten::ScalarType t) {
535 switch (t) {
536 case ::executorch::aten::ScalarType::Byte:
537 return ::executorch::aten::ScalarType::QUInt8;
538 case ::executorch::aten::ScalarType::Char:
539 return ::executorch::aten::ScalarType::QInt8;
540 case ::executorch::aten::ScalarType::Int:
541 return ::executorch::aten::ScalarType::QInt32;
542 default:
543 return t;
544 }
545 }
546
toUnderlying(::executorch::aten::ScalarType t)547 inline ::executorch::aten::ScalarType toUnderlying(
548 ::executorch::aten::ScalarType t) {
549 switch (t) {
550 case ::executorch::aten::ScalarType::QUInt8:
551 return ::executorch::aten::ScalarType::Byte;
552 case ::executorch::aten::ScalarType::QInt8:
553 return ::executorch::aten::ScalarType::Char;
554 case ::executorch::aten::ScalarType::QInt32:
555 return ::executorch::aten::ScalarType::Int;
556 case ::executorch::aten::ScalarType::QUInt4x2:
557 return ::executorch::aten::ScalarType::Byte;
558 case ::executorch::aten::ScalarType::QUInt2x4:
559 return ::executorch::aten::ScalarType::Byte;
560 default:
561 return t;
562 }
563 }
564
isSignedType(::executorch::aten::ScalarType t)565 inline bool isSignedType(::executorch::aten::ScalarType t) {
566 ET_CHECK_MSG(
567 !::executorch::runtime::isQIntType(t),
568 "isSignedType not supported for quantized types like %" PRId8,
569 static_cast<int8_t>(t));
570 #define CASE_SIGNED(ctype, name) \
571 case ::executorch::aten::ScalarType::name: \
572 return std::numeric_limits<ctype>::is_signed;
573
574 switch (t) {
575 case ::executorch::aten::ScalarType::ComplexHalf:
576 case ::executorch::aten::ScalarType::ComplexFloat:
577 case ::executorch::aten::ScalarType::ComplexDouble:
578 return true;
579 ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
580 default:
581 ET_CHECK_MSG(false, "Unknown ScalarType %" PRId8, static_cast<int8_t>(t));
582 }
583 #undef CASE_SIGNED
584 }
585
isUnderlying(::executorch::aten::ScalarType type,::executorch::aten::ScalarType qtype)586 inline bool isUnderlying(
587 ::executorch::aten::ScalarType type,
588 ::executorch::aten::ScalarType qtype) {
589 return type == ::executorch::runtime::toUnderlying(qtype);
590 }
591
toRealValueType(::executorch::aten::ScalarType t)592 inline ::executorch::aten::ScalarType toRealValueType(
593 ::executorch::aten::ScalarType t) {
594 switch (t) {
595 case ::executorch::aten::ScalarType::ComplexHalf:
596 return ::executorch::aten::ScalarType::Half;
597 case ::executorch::aten::ScalarType::ComplexFloat:
598 return ::executorch::aten::ScalarType::Float;
599 case ::executorch::aten::ScalarType::ComplexDouble:
600 return ::executorch::aten::ScalarType::Double;
601 default:
602 return t;
603 }
604 }
605
toComplexType(::executorch::aten::ScalarType t)606 inline ::executorch::aten::ScalarType toComplexType(
607 ::executorch::aten::ScalarType t) {
608 switch (t) {
609 case ::executorch::aten::ScalarType::BFloat16:
610 // BFloat16 has range equivalent to Float,
611 // so we map it to ComplexFloat.
612 return ::executorch::aten::ScalarType::ComplexFloat;
613 case ::executorch::aten::ScalarType::Half:
614 return ::executorch::aten::ScalarType::ComplexHalf;
615 case ::executorch::aten::ScalarType::Float:
616 return ::executorch::aten::ScalarType::ComplexFloat;
617 case ::executorch::aten::ScalarType::Double:
618 return ::executorch::aten::ScalarType::ComplexDouble;
619 case ::executorch::aten::ScalarType::ComplexHalf:
620 return ::executorch::aten::ScalarType::ComplexHalf;
621 case ::executorch::aten::ScalarType::ComplexFloat:
622 return ::executorch::aten::ScalarType::ComplexFloat;
623 case ::executorch::aten::ScalarType::ComplexDouble:
624 return ::executorch::aten::ScalarType::ComplexDouble;
625 default:
626 ET_CHECK_MSG(
627 false,
628 "Unknown Complex ScalarType for %" PRId8,
629 static_cast<int8_t>(t));
630 }
631 }
632
633 /**
634 * Encodes type casting rules that are consistent with ATen behaviour.
635 */
canCast(const::executorch::aten::ScalarType from,const::executorch::aten::ScalarType to)636 inline constexpr bool canCast(
637 const ::executorch::aten::ScalarType from,
638 const ::executorch::aten::ScalarType to) {
639 // Disallow complex -> non-complex
640 return !(::executorch::runtime::isComplexType(from) &&
641 !::executorch::runtime::isComplexType(to)) &&
642 // Disallow float -> integral
643 !(::executorch::runtime::isFloatingType(from) &&
644 ::executorch::runtime::isIntegralType(to, /*includeBool=*/false)) &&
645 // Treat bool as a special category. Disallow non-bool -> bool
646 !(from != ::executorch::aten::ScalarType::Bool &&
647 to == ::executorch::aten::ScalarType::Bool);
648 }
649
650 template <typename T1, typename T2>
651 struct can_cast : std::integral_constant<
652 bool,
653 canCast(
654 CppTypeToScalarType<T1>::value,
655 CppTypeToScalarType<T2>::value)> {};
656
/**
 * Converts `val` from type From to type To.
 *
 * When casting from floating point to a non-bool integral type, a value
 * outside the destination's range is undefined behavior, and sanitizers flag
 * it when enabled. To circumvent this, the value is first truncated to int64_t
 * and then narrowed with an integral conversion. Only values outside the
 * int64_t range remain undefined.
 *
 * bool is deliberately excluded from that path. Floating point -> bool is a
 * boolean conversion: any non-zero value, including NaN, becomes true, and it
 * is always well defined. Routing it through int64_t would truncate e.g. 0.5
 * to 0 and wrongly produce false.
 */
template <
    typename To,
    typename From,
    std::enable_if_t<
        (std::is_floating_point<From>::value && std::is_integral<To>::value &&
         !std::is_same<To, bool>::value),
        int> = 0>
To convert(From val) {
  return static_cast<To>(static_cast<int64_t>(val));
}

/// All other conversions (including floating point -> bool) are a plain
/// static_cast.
template <
    typename To,
    typename From,
    std::enable_if_t<
        !(std::is_floating_point<From>::value && std::is_integral<To>::value &&
          !std::is_same<To, bool>::value),
        int> = 0>
To convert(From val) {
  return static_cast<To>(val);
}
681
namespace internal {
// Short aliases for the 13 ScalarTypes that take part in table-driven type
// promotion. They keep promoteTypesLookup below readable; the promotion rules
// themselves follow NumPy's promote_types.
inline constexpr auto u1 = ::executorch::aten::ScalarType::Byte;
inline constexpr auto i1 = ::executorch::aten::ScalarType::Char;
inline constexpr auto i2 = ::executorch::aten::ScalarType::Short;
inline constexpr auto i4 = ::executorch::aten::ScalarType::Int;
inline constexpr auto i8 = ::executorch::aten::ScalarType::Long;
inline constexpr auto f2 = ::executorch::aten::ScalarType::Half;
inline constexpr auto f4 = ::executorch::aten::ScalarType::Float;
inline constexpr auto f8 = ::executorch::aten::ScalarType::Double;
inline constexpr auto c2 = ::executorch::aten::ScalarType::ComplexHalf;
inline constexpr auto c4 = ::executorch::aten::ScalarType::ComplexFloat;
inline constexpr auto c8 = ::executorch::aten::ScalarType::ComplexDouble;
inline constexpr auto b1 = ::executorch::aten::ScalarType::Bool;
inline constexpr auto bf = ::executorch::aten::ScalarType::BFloat16;

// C++ element types corresponding to the enumerator aliases above.
using U1 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Byte>::type;
using I1 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Char>::type;
using I2 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Short>::type;
using I4 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Int>::type;
using I8 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Long>::type;
using F2 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Half>::type;
using F4 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type;
using F8 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Double>::type;
using C2 = typename ScalarTypeToCppType<
    ::executorch::aten::ScalarType::ComplexHalf>::type;
using C4 = typename ScalarTypeToCppType<
    ::executorch::aten::ScalarType::ComplexFloat>::type;
using C8 = typename ScalarTypeToCppType<
    ::executorch::aten::ScalarType::ComplexDouble>::type;
using B1 =
    typename ScalarTypeToCppType<::executorch::aten::ScalarType::Bool>::type;
using BF = typename ScalarTypeToCppType<
    ::executorch::aten::ScalarType::BFloat16>::type;

// Row/column order of promoteTypesLookup: index2dtype[i] is the ScalarType
// that occupies table index i. Must stay in the same order as the table's
// header comment below.
inline constexpr std::array<::executorch::aten::ScalarType, 13> index2dtype = {
    {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}};

// Builds the inverse of index2dtype: an array indexed by a ScalarType's
// integer value that holds that type's promoteTypesLookup index, or -1 for
// types that have no row/column in the table.
constexpr std::array<
    int64_t,
    static_cast<size_t>(::executorch::aten::ScalarType::NumOptions)>
calculate_dtype2index() {
  std::array<
      int64_t,
      static_cast<size_t>(::executorch::aten::ScalarType::NumOptions)>
      inverse = {};
  // Default every slot to "not in the table".
  for (int64_t i = 0;
       i < static_cast<int64_t>(::executorch::aten::ScalarType::NumOptions);
       i++) {
    inverse[i] = -1;
  }
  // Record each table type's index under its enum value.
  for (int64_t i = 0; i < static_cast<int64_t>(index2dtype.size()); i++) {
    inverse[static_cast<int64_t>(index2dtype[i])] = i;
  }
  return inverse;
}

// Computed once at compile time (the initializer is constexpr).
inline constexpr auto dtype2index = calculate_dtype2index();
// Number of rows/columns in promoteTypesLookup (same as index2dtype.size()).
inline constexpr int NUM_PROMOTE_TYPES = 13;
// promoteTypesLookup[dtype2index[a]][dtype2index[b]] is the result of
// promoting a with b.
// Should match _promoteTypesLookup in c10/core/ScalarType.cpp so that
// we match PyTorch core type promotion semantics.
inline constexpr ::executorch::aten::ScalarType
    promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
        /*         u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  bf*/
        /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf},
        /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf},
        /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf},
        /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf},
        /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf},
        /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4},
        /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4},
        /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8},
        /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4},
        /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4},
        /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
        /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf},
        /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf},
};

} // namespace internal
770
771 /**
772 * Implements type promotion rules that are consistent with ATen behaviour,
773 * which in turn is consistent with NumPy's promote_types.
774 * If half_to_float is set to true, then half and bfloat16 will be promoted to
775 * float instead
776 */
777 inline constexpr ::executorch::aten::ScalarType promoteTypes(
778 ::executorch::aten::ScalarType a,
779 ::executorch::aten::ScalarType b,
780 bool half_to_float = false) {
781 // For QInt types, only allow exact match
782 if (::executorch::runtime::isQIntType(a) && a == b) {
783 return a;
784 }
785 if (::executorch::runtime::isQIntType(a) ||
786 ::executorch::runtime::isQIntType(b)) {
787 ET_CHECK_MSG(false, "promoteTypes not valid for quantized dtypes");
788 }
789
790 // For Bits types, only allow exact match
791 if (::executorch::runtime::isBitsType(a) && a == b) {
792 return a;
793 }
794 if (::executorch::runtime::isBitsType(a) ||
795 ::executorch::runtime::isBitsType(b)) {
796 ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes");
797 }
798
799 // For Float8 types, only allow exact match
800 if (::executorch::runtime::isFloat8Type(a) && a == b) {
801 return a;
802 }
803 if (::executorch::runtime::isFloat8Type(a) ||
804 ::executorch::runtime::isFloat8Type(b)) {
805 ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes");
806 }
807
808 // For barebones uint types, only allow exact match
809 if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) {
810 return a;
811 }
812 if (::executorch::runtime::isBarebonesUnsignedType(a) ||
813 ::executorch::runtime::isBarebonesUnsignedType(b)) {
814 ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes");
815 }
816
817 auto ix_a = ::executorch::runtime::internal::dtype2index[(int)a];
818 ET_CHECK(ix_a != -1);
819 auto ix_b = ::executorch::runtime::internal::dtype2index[(int)b];
820 ET_CHECK(ix_b != -1);
821 ::executorch::aten::ScalarType promoted_type =
822 ::executorch::runtime::internal::promoteTypesLookup[ix_a][ix_b];
823
824 if (half_to_float &&
825 (promoted_type == ::executorch::aten::ScalarType::Half ||
826 promoted_type == ::executorch::aten::ScalarType::BFloat16)) {
827 promoted_type = ::executorch::aten::ScalarType::Float;
828 }
829
830 return promoted_type;
831 }
832
833 template <typename T1, typename T2, bool half_to_float = false>
834 struct promote_types {
835 private:
836 static_assert(
837 std::is_same_v<T1, T2> ||
838 (!is_qint_type<T1>::value && !is_qint_type<T2>::value),
839 "promote_types not valid for quantized dtypes");
840 static_assert(
841 std::is_same_v<T1, T2> ||
842 (!is_bits_type<T1>::value && !is_bits_type<T2>::value),
843 "promote_types not valid for bits dtypes");
844 static_assert(
845 std::is_same_v<T1, T2> ||
846 (!is_float8_type<T1>::value && !is_float8_type<T2>::value),
847 "promote_types not valid for float8 dtypes");
848 static_assert(
849 std::is_same_v<T1, T2> ||
850 (!is_barebones_unsigned_type<T1>::value &&
851 !is_barebones_unsigned_type<T2>::value),
852 "promote_types not valid for barebones unsigned dtypes");
853
854 using promoted_type_not_respecting_half_to_float =
855 typename ScalarTypeToCppType<promoteTypes(
856 CppTypeToScalarType<T1>::value,
857 CppTypeToScalarType<T2>::value)>::type;
858
859 public:
860 using type = std::conditional_t<
861 half_to_float &&
862 (std::is_same_v<
863 promoted_type_not_respecting_half_to_float,
864 typename ScalarTypeToCppType<
865 ::executorch::aten::ScalarType::Half>::type> ||
866 std::is_same_v<
867 promoted_type_not_respecting_half_to_float,
868 typename ScalarTypeToCppType<
869 ::executorch::aten::ScalarType::BFloat16>::type>),
870 typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type,
871 promoted_type_not_respecting_half_to_float>;
872 };
873
874 //
875 // Helper macros for switch case macros (see below)
876 //
877 // These macros are not meant to be used directly. They provide an easy way to
878 // generate a switch statement that can handle subsets of ScalarTypes supported
879 // by ExecuTorch.
880 //
881
// ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, fn) emits one `case` label
// for an ET_INTERNAL_SWITCH. Inside the case, CTYPE_ALIAS names the C++ type
// that ScalarTypeToCppType associates with `enum_type`. The case returns the
// result of invoking the callable passed as the trailing argument(s).
//
// If ET_INTERNAL_CHECK_SELECTIVE_BUILD is already defined when this header is
// processed, each case first invokes it with `enum_type`. Otherwise the hook
// below expands to a no-op.
#ifdef ET_INTERNAL_CHECK_SELECTIVE_BUILD
#define ET_INTERNAL_SWITCH_CASE_SELECTIVE_BUILD_HOOK(enum_type) \
  ET_INTERNAL_CHECK_SELECTIVE_BUILD(enum_type)
#else
#define ET_INTERNAL_SWITCH_CASE_SELECTIVE_BUILD_HOOK(enum_type) \
  static_cast<void>(0)
#endif

#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...)         \
  case enum_type: {                                                  \
    ET_INTERNAL_SWITCH_CASE_SELECTIVE_BUILD_HOOK(enum_type);         \
    using CTYPE_ALIAS =                                              \
        ::executorch::runtime::ScalarTypeToCppType<enum_type>::type; \
    return __VA_ARGS__();                                            \
  }
898
// ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, cases...) expands to an
// immediately-invoked `[&]` lambda that switches on TYPE. The trailing
// arguments are spliced in verbatim as the switch's case labels (normally
// produced by the ET_INTERNAL_SWITCH_CASE* macros). The lambda's return type
// is therefore deduced from their `return` statements.
//
// A dtype with no matching case falls into `default`, which calls
// ET_CHECK_MSG(false, ...) with the dtype's toString() and NAME.
//
// Notes:
//  - CONTEXT is accepted but is not referenced anywhere in this expansion.
//  - NAME initializes a `constexpr const char*`, so it must be a constant
//    expression (e.g. a string literal).
//  - The locals `_st` and `et_switch_name` are in scope for the case bodies
//    through the `[&]` capture. Renaming them could break callers that
//    reference them.
//  - No statement follows the switch. The lambda presumably relies on
//    ET_CHECK_MSG not returning on failure (NOTE(review): confirm against
//    platform/assert.h).
#define ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, ...) \
  [&] {                                              \
    const auto& _st = TYPE;                          \
    constexpr const char* et_switch_name = NAME;     \
    (void)et_switch_name; /* Suppress unused var */  \
    switch (_st) {                                   \
      __VA_ARGS__                                    \
      default:                                       \
        ET_CHECK_MSG(                                \
            false,                                   \
            "Unhandled dtype %s for %s",             \
            ::executorch::runtime::toString(_st),    \
            et_switch_name);                         \
    }                                                \
  }()
914
// Case labels for every dtype handled by ET_SWITCH_ALL_TYPES, grouped by
// family: integral types and Bool, floating point, complex, quantized, and
// bit-packed. No float8 or barebones unsigned dtypes are included.
#define ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...)                  \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__)        \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__)        \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__)       \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__)         \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__)        \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__)        \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Half, CTYPE_ALIAS, __VA_ARGS__)        \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::BFloat16, CTYPE_ALIAS, __VA_ARGS__)    \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__)       \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)      \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__)       \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__)      \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__)      \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__)    \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__)    \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bits1x8, CTYPE_ALIAS, __VA_ARGS__)     \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bits2x4, CTYPE_ALIAS, __VA_ARGS__)     \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bits4x2, CTYPE_ALIAS, __VA_ARGS__)     \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bits8, CTYPE_ALIAS, __VA_ARGS__)       \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::Bits16, CTYPE_ALIAS, __VA_ARGS__)
962
// Case labels for the "real" dtypes: the integral types Byte, Char, Short,
// Int and Long, plus Float and Double. Bool, Half and BFloat16 are not
// included; the _AND variants below add such extras.
#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...)          \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__)  \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                            \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)

// The real dtypes plus one extra dtype, given as a bare ScalarType
// enumerator name (e.g. `Half`).
#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)             \
  ET_INTERNAL_SWITCH_CASE(                                                 \
      ::executorch::aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)

// The real dtypes plus two extras, layered on the single-extra variant.
#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(                             \
    ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, ...)                              \
  ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(ADDITIONAL1, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)

// The real dtypes plus three extras, each spelled out explicitly.
#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3(                             \
    ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, ...)                 \
  ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)               \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL1, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__)
998
// Case labels for the integral dtypes Byte, Char, Short, Int and Long.
// Bool is not included.
#define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...)            \
  ET_INTERNAL_SWITCH_CASE(                                             \
      ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__)  \
  ET_INTERNAL_SWITCH_CASE(                                             \
      ::executorch::aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__)  \
  ET_INTERNAL_SWITCH_CASE(                                             \
      ::executorch::aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                             \
      ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__)   \
  ET_INTERNAL_SWITCH_CASE(                                             \
      ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__)

// The integral dtypes plus one extra dtype, given as a bare ScalarType
// enumerator name.
#define ET_INTERNAL_SWITCH_CASE_INT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)             \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)

// Case labels for Float and Double. Half and BFloat16 come only through the
// _AND variants.
#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...)           \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__)  \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)

// Float and Double plus one extra dtype.
#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)             \
  ET_INTERNAL_SWITCH_CASE(                                                  \
      ::executorch::aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)

// Float and Double plus two extras, each spelled out explicitly.
#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2(                            \
    ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, ...)                              \
  ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)              \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL1, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                   \
      ::executorch::aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)
1033
// Case labels for the quantized dtypes (signed ones first, then unsigned).
#define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...)              \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__)    \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__)   \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__)   \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__)

// Case labels for ComplexFloat and ComplexDouble. ComplexHalf is not
// included.
#define ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...)                \
  ET_INTERNAL_SWITCH_CASE(                                                     \
      ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__)  \
  ET_INTERNAL_SWITCH_CASE(                                                     \
      ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)

// The ET_SWITCH_SCALAR_OBJ_* family works on subsets of {Bool, Long, Double}.
// Bool and Long.
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_INTB_TYPES(CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__)   \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__)

// Bool, Long and Double: the Bool/Long pair above plus Double.
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...)   \
  ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_INTB_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
  ET_INTERNAL_SWITCH_CASE(                                           \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)

// Long and Double.
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_REAL_TYPES(CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__)   \
  ET_INTERNAL_SWITCH_CASE(                                              \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)

// Bool and Double.
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_FLOATB_TYPES(CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__)     \
  ET_INTERNAL_SWITCH_CASE(                                                \
      ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__)
1077
1078 //
1079 // Switch case macros
1080 //
1081 // These macros provide an easy way to generate switch statements that apply a
1082 // common lambda function to subsets of ScalarTypes supported by ExecuTorch.
1083 // The lambda function can type specialize to the ctype associated with the
1084 // ScalarType being handled through an alias passed as the CTYPE_ALIAS argument.
1085 //
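// Arguments:
//  - ADDITIONAL, ADDITIONAL1..3: Extra ScalarType enumerator names (e.g. Half,
//    Bool, BFloat16) whose cases are added to the base set of the *_AND macros
//  - T1, T2, T3: The ScalarType enumerator names handled by
//    ET_SWITCH_TWO_TYPES / ET_SWITCH_THREE_TYPES
//  - TYPE: The ScalarType to handle through the switch statement
//  - CONTEXT: The KernelRuntimeContext instance of the calling kernel. Note
//    that the switch expansion itself does not currently reference it. An
//    unhandled dtype is reported via ET_CHECK_MSG(false, ...) instead.
//  - NAME: A name for this operation, used in error messages. It must be a
//    constant expression (e.g. a string literal).
//  - CTYPE_ALIAS: A typedef for the ctype associated with the ScalarType.
//  - [&](){...}: A lambda function to be applied to each ScalarType case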
1093 //
// An example usage is:
//
//  ET_SWITCH_REAL_TYPES(input.scalar_type(), ctx, "example", CTYPE, [&]() {
//    output.mutable_data_ptr<CTYPE>()[0] = input.const_data_ptr<CTYPE>()[0];
//  });
//
// Note that these can be nested as well:
//
//  ET_SWITCH_REAL_TYPES(input.scalar_type(), ctx, "example", CTYPE_IN, [&]() {
//    ET_SWITCH_REAL_TYPES(output.scalar_type(), ctx, "example", CTYPE_OUT, [&]() {
//      output.mutable_data_ptr<CTYPE_OUT>()[0] =
//          input.const_data_ptr<CTYPE_IN>()[0];
//    });
//  });
1108 //
1109 // These macros are adapted from Dispatch.h in the ATen library. The primary
1110 // difference is that the CTYPE_ALIAS argument is exposed to users, which is
1111 // used to alias the ctype associated with the ScalarType that is being handled.
1112 //
1113
// Every dtype listed in ET_INTERNAL_SWITCH_CASE_ALL_TYPES.
#define ET_SWITCH_ALL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                              \
      TYPE, CONTEXT, NAME,                                         \
      ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// Byte, Char, Short, Int, Long, Float and Double.
#define ET_SWITCH_REAL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                               \
      TYPE, CONTEXT, NAME,                                          \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// The real dtypes plus one extra ScalarType enumerator name.
#define ET_SWITCH_REAL_TYPES_AND(                      \
    ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                  \
      TYPE, CONTEXT, NAME,                             \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(          \
          ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__))

// The real dtypes plus two extra ScalarType enumerator names.
#define ET_SWITCH_REAL_TYPES_AND2(                                   \
    ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                \
      TYPE, CONTEXT, NAME,                                           \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(                       \
          ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__))

// The real dtypes plus three extra ScalarType enumerator names.
#define ET_SWITCH_REAL_TYPES_AND3(                                     \
    ADDITIONAL1,                                                       \
    ADDITIONAL2,                                                       \
    ADDITIONAL3,                                                       \
    TYPE,                                                              \
    CONTEXT,                                                           \
    NAME,                                                              \
    CTYPE_ALIAS,                                                       \
    ...)                                                               \
  ET_INTERNAL_SWITCH(                                                  \
      TYPE, CONTEXT, NAME,                                             \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3(                         \
          ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__))

// Real + Half.
#define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                \
      TYPE, CONTEXT, NAME,                                           \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(                        \
          Half, CTYPE_ALIAS, __VA_ARGS__))

// Real + Half + BFloat16.
#define ET_SWITCH_REALHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                    \
      TYPE, CONTEXT, NAME,                                               \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(                           \
          Half, BFloat16, CTYPE_ALIAS, __VA_ARGS__))

// Real + Bool.
#define ET_SWITCH_REALB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                \
      TYPE, CONTEXT, NAME,                                           \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(                        \
          Bool, CTYPE_ALIAS, __VA_ARGS__))

// Real + Half + Bool.
#define ET_SWITCH_REALHB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                 \
      TYPE, CONTEXT, NAME,                                            \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(                        \
          Half, Bool, CTYPE_ALIAS, __VA_ARGS__))

// Real + Half + Bool + BFloat16.
#define ET_SWITCH_REALHBBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                     \
      TYPE, CONTEXT, NAME,                                                \
      ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3(                            \
          Half, Bool, BFloat16, CTYPE_ALIAS, __VA_ARGS__))
1179
// Byte, Char, Short, Int and Long.
#define ET_SWITCH_INT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                              \
      TYPE, CONTEXT, NAME,                                         \
      ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// The integral dtypes plus one extra ScalarType enumerator name.
#define ET_SWITCH_INT_TYPES_AND(                       \
    ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                  \
      TYPE, CONTEXT, NAME,                             \
      ET_INTERNAL_SWITCH_CASE_INT_TYPES_AND(           \
          ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__))

// Float and Double.
#define ET_SWITCH_FLOAT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                \
      TYPE, CONTEXT, NAME,                                           \
      ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// Float and Double plus one extra ScalarType enumerator name.
#define ET_SWITCH_FLOAT_TYPES_AND(                     \
    ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                  \
      TYPE, CONTEXT, NAME,                             \
      ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(         \
          ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__))

// Float and Double plus two extra ScalarType enumerator names.
#define ET_SWITCH_FLOAT_TYPES_AND2(                                  \
    ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                \
      TYPE, CONTEXT, NAME,                                           \
      ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2(                      \
          ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__))

// Float + Double + Half.
#define ET_SWITCH_FLOATH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                 \
      TYPE, CONTEXT, NAME,                                            \
      ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(                        \
          Half, CTYPE_ALIAS, __VA_ARGS__))

// Float + Double + Half + BFloat16.
#define ET_SWITCH_FLOATHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                     \
      TYPE, CONTEXT, NAME,                                                \
      ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2(                           \
          Half, BFloat16, CTYPE_ALIAS, __VA_ARGS__))
1227
// The quantized dtypes listed in ET_INTERNAL_SWITCH_CASE_QINT_TYPES.
#define ET_SWITCH_QINT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                               \
      TYPE, CONTEXT, NAME,                                          \
      ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// ComplexFloat and ComplexDouble.
#define ET_SWITCH_COMPLEX_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                  \
      TYPE, CONTEXT, NAME,                                             \
      ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// The SCALAR_OBJ family dispatches over subsets of {Bool, Long, Double},
// expressed here through the explicit two/three-dtype switches.

// Bool, Long and Double.
#define ET_SWITCH_SCALAR_OBJ_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_SWITCH_THREE_TYPES(                                                  \
      Bool, Long, Double, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__)

// Long and Double.
#define ET_SWITCH_SCALAR_OBJ_REAL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_SWITCH_TWO_TYPES(                                                         \
      Long, Double, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__)

// Bool and Long.
#define ET_SWITCH_SCALAR_OBJ_INTB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_SWITCH_TWO_TYPES(                                                         \
      Bool, Long, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__)

// Bool and Double.
#define ET_SWITCH_SCALAR_OBJ_FLOATB_TYPES(     \
    TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...)     \
  ET_SWITCH_TWO_TYPES(                         \
      Bool, Double, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__)
1271
// Dispatches over exactly two dtypes, given as bare ScalarType enumerator
// names (e.g. `Float, Double`).
#define ET_SWITCH_TWO_TYPES(T1, T2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
  ET_INTERNAL_SWITCH(                                                     \
      TYPE, CONTEXT, NAME,                                                \
      ET_INTERNAL_SWITCH_CASE(                                            \
          ::executorch::aten::ScalarType::T1, CTYPE_ALIAS, __VA_ARGS__)   \
      ET_INTERNAL_SWITCH_CASE(                                            \
          ::executorch::aten::ScalarType::T2, CTYPE_ALIAS, __VA_ARGS__))

// Dispatches over exactly three dtypes, given as bare ScalarType enumerator
// names.
#define ET_SWITCH_THREE_TYPES(                                          \
    T1, T2, T3, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...)                  \
  ET_INTERNAL_SWITCH(                                                   \
      TYPE, CONTEXT, NAME,                                              \
      ET_INTERNAL_SWITCH_CASE(                                          \
          ::executorch::aten::ScalarType::T1, CTYPE_ALIAS, __VA_ARGS__) \
      ET_INTERNAL_SWITCH_CASE(                                          \
          ::executorch::aten::ScalarType::T2, CTYPE_ALIAS, __VA_ARGS__) \
      ET_INTERNAL_SWITCH_CASE(                                          \
          ::executorch::aten::ScalarType::T3, CTYPE_ALIAS, __VA_ARGS__))
1296
1297 } // namespace runtime
1298 } // namespace executorch
1299
1300 namespace executorch {
1301 namespace aten {
1302 #ifdef USE_ATEN_LIB
1303 using ::at::elementSize;
1304 #else // USE_ATEN_LIB
1305 using ::executorch::runtime::elementSize;
1306 #endif // USE_ATEN_LIB
1307 } // namespace aten
1308 } // namespace executorch
1309
1310 namespace torch {
1311 namespace executor {
1312 // TODO(T197294990): Remove these deprecated aliases once all users have moved
1313 // to the new `::executorch` namespaces.
1314 using ::executorch::runtime::can_cast;
1315 using ::executorch::runtime::canCast;
1316 using ::executorch::runtime::convert;
1317 using ::executorch::runtime::CppTypeToScalarType;
1318 using ::executorch::runtime::elementSize;
1319 using ::executorch::runtime::is_barebones_unsigned_type;
1320 using ::executorch::runtime::is_bits_type;
1321 using ::executorch::runtime::is_complex_type;
1322 using ::executorch::runtime::is_float8_type;
1323 using ::executorch::runtime::is_integral_type;
1324 using ::executorch::runtime::is_qint_type;
1325 using ::executorch::runtime::isBitsType;
1326 using ::executorch::runtime::isComplexType;
1327 using ::executorch::runtime::isFloatingType;
1328 using ::executorch::runtime::isIntegralType;
1329 using ::executorch::runtime::isQIntType;
1330 using ::executorch::runtime::isRealHBType;
1331 using ::executorch::runtime::isRealHType;
1332 using ::executorch::runtime::isRealType;
1333 using ::executorch::runtime::isValid;
1334 using ::executorch::runtime::promote_types;
1335 using ::executorch::runtime::promoteTypes;
1336 using ::executorch::runtime::ScalarTypeToCppType;
1337 using ::executorch::runtime::toString;
1338 #if !defined(USE_ATEN_LIB)
1339 using ::executorch::runtime::is_floating_point;
1340 using ::executorch::runtime::is_reduced_floating_point;
1341 #endif
1342 namespace internal {
1343 using ::executorch::runtime::internal::B1;
1344 using ::executorch::runtime::internal::C2;
1345 using ::executorch::runtime::internal::C4;
1346 using ::executorch::runtime::internal::C8;
1347 using ::executorch::runtime::internal::F2;
1348 using ::executorch::runtime::internal::F4;
1349 using ::executorch::runtime::internal::F8;
1350 using ::executorch::runtime::internal::I1;
1351 using ::executorch::runtime::internal::I2;
1352 using ::executorch::runtime::internal::I4;
1353 using ::executorch::runtime::internal::I8;
1354 using ::executorch::runtime::internal::U1;
1355 } // namespace internal
1356 } // namespace executor
1357 } // namespace torch
1358