// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Dispatch.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <climits>
#include <string>

#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <c10/util/string_view.h>

#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
#endif

15 #ifdef TEMPLATE_SELECTIVE_BUILD
16 #include <ATen/selected_mobile_ops.h>
17 #else
18 namespace at {
19 /**
20  * The method should_include_kernel_dtype() returns true/false
21  * based on whether the switching code for a specific dtype should be
22  * included based on build time constants generated from tracing model
23  * execution. This method will be implemented via code-generation and
24  * included in this file when code-gen is ready.
25  */
should_include_kernel_dtype(const char *,at::ScalarType)26 inline constexpr bool should_include_kernel_dtype(
27     const char* /*kernel_tag_str*/,
28     at::ScalarType /*scalar_type*/
29 ) {
30   return true;
31 }
32 } // namespace at
33 #endif
34 
/**
 * In the Facebook internal build (using BUCK), this macro is enabled by
 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary.
 *
 * When ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is defined,
 * RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) passes the string
 * "<NAME>$<toString(enum_type)>" to at::detail::record_kernel_function_dtype.
 * Otherwise it expands to nothing.
 *
 * NOTE: the enabled expansion already ends in ';'. A call site that adds its
 * own ';' (as AT_DISPATCH_SWITCH does) just produces an empty statement. The
 * trailing ';' is presumably kept for call sites that omit their own, so it is
 * left unchanged here.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
namespace at {
namespace detail {
TORCH_API void record_kernel_function_dtype(std::string name);
}
} // namespace at

#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  at::detail::record_kernel_function_dtype(           \
      std::string(NAME) + "$" + toString(enum_type));
#else
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
#endif

// Selective-build guard emitted at the top of every dispatch case.
//
// It expects a constexpr `const char* at_dispatch_name` to be in scope;
// AT_DISPATCH_SWITCH declares one. When at::should_include_kernel_dtype()
// reports that `enum_type` is not selected for that kernel tag, reaching the
// case raises an error through AT_ERROR with the message:
//   dtype '<enum_type>' not selected for kernel tag <at_dispatch_name>
// The check uses `if constexpr`, so the predicate must be a constant
// expression and `enum_type` must be a compile-time ScalarType.
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)   \
  do {                                                \
    if constexpr (!at::should_include_kernel_dtype(   \
                      at_dispatch_name, enum_type)) { \
      AT_ERROR(                                       \
          "dtype '",                                  \
          toString(enum_type),                        \
          "' not selected for kernel tag ",           \
          at_dispatch_name);                          \
    }                                                 \
  } while (0)

// Emits one `case enum_type:` arm of a dispatch switch. Inside the arm, the
// C++ type that corresponds to `enum_type` (c10::impl::ScalarTypeToCPPTypeT)
// is aliased as `HINT`. The arm then returns the result of invoking the
// callable passed as __VA_ARGS__ with no arguments, e.g. `[&] { ... }`.
// `enum_type` must be a constant expression: it is used both as a case label
// and as a template argument.
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)           \
  case enum_type: {                                                     \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                        \
    using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
    return __VA_ARGS__();                                               \
  }

// The common case: expose the dtype's C++ type to the callable as `scalar_t`.
#define AT_DISPATCH_CASE(enum_type, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)

// Dispatch case for a quantized dtype. The callable can use these names:
//   scalar_t        - the quantized C++ type (`scalar_type`)
//   underlying_t    - scalar_t::underlying
//   SCALAR_TYPE     - the quantized at::ScalarType (`enum_type`)
//   UNDERLYING_TYPE - toUnderlying(enum_type); presumably the ScalarType of
//                     the underlying storage (TODO confirm)
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...)            \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    return __VA_ARGS__();                                             \
  }

// Like AT_DISPATCH_CASE_QINT, but also exposes these names to the callable:
//   bit_width - `bitwidth`, the bit width given for the dtype
//   quant_min - `qmin`, the smallest quantized value
//   quant_max - `qmax`, the largest quantized value
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                           \
    enum_type, scalar_type, bitwidth, qmin, qmax, ...)                \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    C10_UNUSED int bit_width = bitwidth;                              \
    C10_UNUSED int64_t quant_min = qmin;                              \
    C10_UNUSED int64_t quant_max = qmax;                              \
    return __VA_ARGS__();                                             \
  }

100 namespace detail {
101 
scalar_type(at::ScalarType s)102 inline at::ScalarType scalar_type(at::ScalarType s) {
103   return s;
104 }
105 
106 C10_DEPRECATED_MESSAGE(
107     "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
108     "pass an at::ScalarType instead")
scalar_type(const at::DeprecatedTypeProperties & t)109 inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
110   return t.scalarType();
111 }
112 
113 C10_DEPRECATED_MESSAGE(
114     "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
115     "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF()116 inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
117 
118 C10_DEPRECATED_MESSAGE(
119     "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
120     "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
121     "instead")
deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX()122 inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
123 
124 } // namespace detail
125 
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch.  We call it "dispatch" because
// we are "dispatching" to the correct, dtype-specific kernel.
//
// A standard usage looks like:
//
//      AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
//          // Your code here, with 'scalar_t' now defined to
//          // be the dtype in question
//      });
//
// There are many variations of this macro, so it's important to
// understand exactly /which/ dtypes you want to get instantiated, as
// well as what the "default" set is.
//
// The default set of dtypes that are instantiated (e.g., by
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
// but NOT booleans (bool), half-precision floats (Half) or
// complex numbers (c10::complex<float>, c10::complex<double>).
// This "cut" is somewhat historical (the default types are the
// ones that TH historically supported), but it also reflects the
// fact that the non-default types are "poorly" behaved (booleans
// are NOT integers mod 2, half precision operations ~essentially
// don't exist on CPU, complex numbers are an experimental application).
//
// Here are the questions you should generally ask to decide which
// dispatch you want:
//
// 1. Is this an integral or floating point specific operation?
//    (If so, you'll want one of the FLOATING or INTEGRAL macros.)
//
// 2. Should half be supported?  (If you're on CPU, the answer is almost
//    definitely no.)  If you do want support, pass at::ScalarType::Half to
//    one of the _AND macros, e.g.
//    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, ...) or
//    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...). The older
//    AT_DISPATCH_ALL_TYPES_AND_HALF spelling is deprecated.
//
// Much rarer situations:
//
// 3. Should bool be supported?  (You often have to write your kernel
//    differently if arithmetic operations are involved.)  If so,
//    Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
//
// 4. Should complex be supported?  The answer is almost always no,
//    unless you are working on "generic" code that should work on
//    all dtypes.
//
// Parameters:
// -----------
//
// 1. The NAME argument is a "tag" that is used to trace and then
//    conditionally compile fragments of the case statements such
//    that the kernel functions are specialized only for the dtypes
//    that are needed. The NAME parameter *must* be a build time
//    const char* (can't be std::string, etc...)
//
// Please ensure that the NAME is unique for every implementation
// or you run the risk of over-including code for the kernel
// functions. There is no risk of missing out on any code, so
// it's mostly a risk of a Type-2 error, and not a Type-1 error
// (a shared tag can only keep extra dtype cases compiled in; it
// never drops a case that is actually needed).
//
// 2. TYPE is the dtype to switch on: an at::ScalarType, or (deprecated) an
//    at::DeprecatedTypeProperties. It is evaluated exactly once.
//
// 3. The trailing argument is a callable taking no arguments, usually
//    `[&] { ... }`. The selected case returns whatever it returns, and that
//    value is the result of the whole dispatch. Because every case returns
//    from the same lambda, every case must produce the same type.
//
// Switch-like syntax:
// -------------------
// There is also a switch-case like syntax which is useful if a kernel
// needs to be specialized for particular scalar types
//
//      AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
//          AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
//            op_integral<scalar_t>(iter);
//          })
//          AT_DISPATCH_CASE_FLOATING_TYPES([&] {
//            op_floating<scalar_t>(iter);
//          })
//          AT_DISPATCH_CASE(kBool, [&] {
//            op_bool(iter);
//          })
//      );
//
// For each AT_DISPATCH_FOO macro, there is a corresponding
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
// AT_DISPATCH_SWITCH block.
//
// If no case matches, the default branch raises an error via AT_ERROR with
// the message:  "<NAME>" not implemented for '<dtype>'

// NB: the the_type variable is not used, but we have kept it for
// backwards compatibility.  It's probably not used by anyone though;
// but we're just being safe (and it doesn't hurt.)  Note we must
// use it to shut up warnings about unused store.

#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
  [&] {                                                                     \
    const auto& the_type = TYPE;                                            \
    constexpr const char* at_dispatch_name = NAME;                          \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
    switch (_st) {                                                          \
      __VA_ARGS__                                                           \
      default:                                                              \
        AT_ERROR(                                                           \
            '"',                                                            \
            at_dispatch_name,                                               \
            "\" not implemented for '",                                     \
            toString(_st),                                                  \
            "'");                                                           \
    }                                                                       \
  }()

// Cases: Double, Float.
#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Cases: Double, Float, Half. This is the same case set as
// AT_DISPATCH_CASE_FLOATING_TYPES_AND(at::ScalarType::Half, ...).
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))

// Cases: Half, BFloat16 only. Float and Double are NOT included.
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...)  \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))

// Floating types (Double, Float) plus 1 to 4 extra dtypes, given as the
// leading SCALARTYPE arguments. Each extra dtype becomes its own case label,
// so the extras must be distinct from each other and from Double/Float
// (duplicate case labels do not compile).
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                                \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND2(       \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                          \
      TYPE,                                    \
      NAME,                                    \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(    \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3(   \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)  \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND3(                    \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND3(                 \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND4(                                 \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                              \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

// Cases: ComplexDouble, ComplexFloat.
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
  AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))

// Complex types plus one extra dtype (must not be ComplexDouble/ComplexFloat).
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                              \
      TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Cases: Double, Float, ComplexDouble, ComplexFloat.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))

// Floating and complex types plus one extra dtype. Note the "AND1" suffix:
// this family spells its single-extra variant AND1 rather than AND.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(    \
    SCALARTYPE, TYPE, NAME, ...)                        \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
          SCALARTYPE, __VA_ARGS__))

// Floating and complex types plus 2 to 4 extra dtypes. The extras must be
// distinct from each other and from the four built-in cases, since each one
// becomes its own case label.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(  \
    SCALARTYPE1, SCALARTYPE2, ...)                         \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(    \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)          \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(  \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)            \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(        \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(     \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(    \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)   \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(                     \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(                  \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

// Floating and complex types plus 5 or 6 extra (distinct) dtypes.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5(                 \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5(    \
    SCALARTYPE1,                                        \
    SCALARTYPE2,                                        \
    SCALARTYPE3,                                        \
    SCALARTYPE4,                                        \
    SCALARTYPE5,                                        \
    TYPE,                                               \
    NAME,                                               \
    ...)                                                \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
          SCALARTYPE1,                                  \
          SCALARTYPE2,                                  \
          SCALARTYPE3,                                  \
          SCALARTYPE4,                                  \
          SCALARTYPE5,                                  \
          __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6(  \
    SCALARTYPE1,                                           \
    SCALARTYPE2,                                           \
    SCALARTYPE3,                                           \
    SCALARTYPE4,                                           \
    SCALARTYPE5,                                           \
    SCALARTYPE6,                                           \
    ...)                                                   \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6(    \
    SCALARTYPE1,                                        \
    SCALARTYPE2,                                        \
    SCALARTYPE3,                                        \
    SCALARTYPE4,                                        \
    SCALARTYPE5,                                        \
    SCALARTYPE6,                                        \
    TYPE,                                               \
    NAME,                                               \
    ...)                                                \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
          SCALARTYPE1,                                  \
          SCALARTYPE2,                                  \
          SCALARTYPE3,                                  \
          SCALARTYPE4,                                  \
          SCALARTYPE5,                                  \
          SCALARTYPE6,                                  \
          __VA_ARGS__))

// Cases: Byte, Char, Int, Long, Short. Bool is NOT included; add it with
// AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, ...) if needed.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

// Integral types plus one extra dtype (must not already be an integral case).
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// The "default" set: integral types followed by Double and Float.
#define AT_DISPATCH_CASE_ALL_TYPES(...)        \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))

// Quantized cases: QInt8 (at::qint8), QUInt8 (at::quint8), QInt32 (at::qint32).
#define AT_DISPATCH_CASE_QINT_TYPES(...)                      \
  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__)   \
  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))

// Quantized types plus one extra dtype. The extra dtype goes through the plain
// AT_DISPATCH_CASE, so in its case only `scalar_t` is defined. The qint-only
// names (underlying_t, SCALAR_TYPE, UNDERLYING_TYPE) are not available there.
#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Byte-sized quantized cases only: QInt8, QUInt8.
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)               \
  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)

#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))

// Quantized and sub-byte quantized cases. Each case also exposes bit_width,
// quant_min and quant_max:
//   QInt8    : CHAR_BIT bits,             [SCHAR_MIN, SCHAR_MAX]
//   QUInt8   : CHAR_BIT bits,             [0, UCHAR_MAX]
//   QInt32   : CHAR_BIT * sizeof(int),    [INT_MIN, INT_MAX]
//   QUInt4x2 : 4 bits,                    [0, 15]
//   QUInt2x4 : 2 bits,                    [0, 3]
// The limit macros come from <climits>.
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                     \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)       \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt32,                                                        \
      at::qint32,                                                         \
      CHAR_BIT * sizeof(int),                                             \
      INT_MIN,                                                            \
      INT_MAX,                                                            \
      __VA_ARGS__)                                                        \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                 \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)

#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))

// The default set (integral + Double/Float) plus ComplexDouble/ComplexFloat.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                      \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))

// The default set plus one extra dtype, e.g. at::ScalarType::Bool or
// at::ScalarType::Half.
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                          \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// The default set, complex types, and one extra dtype.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE,                                                                \
      NAME,                                                                \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))

// The default set (optionally with complex) plus two extra, distinct dtypes.
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                              \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                           \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                         \
      TYPE,                                                                   \
      NAME,                                                                   \
      AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(  \
    SCALARTYPE1, SCALARTYPE2, ...)                    \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(    \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)     \
  AT_DISPATCH_SWITCH(                              \
      TYPE,                                        \
      NAME,                                        \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

// The default set (optionally with complex) plus three extra, distinct dtypes.
#define AT_DISPATCH_CASE_ALL_TYPES_AND3(        \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)       \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND3(                         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND3(                      \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(  \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)       \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(          \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

// Case list: everything AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX emits,
// followed by one AT_DISPATCH_CASE per extra scalar type (DTYPE1..DTYPE4,
// in argument order). The trailing arguments are forwarded to every case.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                    \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, ...)                                \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                   \
  AT_DISPATCH_CASE(DTYPE1, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE2, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE3, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE4, __VA_ARGS__)
616 
// Full dispatch over TYPE: hands NAME and the case list produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4 (four extra scalar types)
// to AT_DISPATCH_SWITCH. All trailing arguments are forwarded verbatim.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                        \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, TYPE, NAME, ...)                   \
  AT_DISPATCH_SWITCH(                                                  \
      TYPE,                                                            \
      NAME,                                                            \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                     \
          DTYPE1, DTYPE2, DTYPE3, DTYPE4, __VA_ARGS__))
624 
// Case list: everything AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX emits,
// followed by one AT_DISPATCH_CASE per extra scalar type (DTYPE1..DTYPE5,
// in argument order). The trailing arguments are forwarded to every case.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                    \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, ...)                        \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                   \
  AT_DISPATCH_CASE(DTYPE1, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE2, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE3, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE4, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE5, __VA_ARGS__)
633 
// Full dispatch over TYPE: hands NAME and the case list produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5 (five extra scalar types)
// to AT_DISPATCH_SWITCH. All trailing arguments are forwarded verbatim.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                        \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, TYPE, NAME, ...)           \
  AT_DISPATCH_SWITCH(                                                  \
      TYPE,                                                            \
      NAME,                                                            \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                     \
          DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, __VA_ARGS__))
653 
// Case list: everything AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX emits,
// followed by one AT_DISPATCH_CASE per extra scalar type (DTYPE1..DTYPE6,
// in argument order). The trailing arguments are forwarded to every case.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(                    \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, ...)                \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                   \
  AT_DISPATCH_CASE(DTYPE1, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE2, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE3, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE4, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE5, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE6, __VA_ARGS__)
669 
// Full dispatch over TYPE: hands NAME and the case list produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6 (six extra scalar types)
// to AT_DISPATCH_SWITCH. All trailing arguments are forwarded verbatim.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                        \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, TYPE, NAME, ...)   \
  AT_DISPATCH_SWITCH(                                                  \
      TYPE,                                                            \
      NAME,                                                            \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(                     \
          DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, __VA_ARGS__))
691 
// Case list: everything AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX emits,
// followed by one AT_DISPATCH_CASE per extra scalar type (DTYPE1..DTYPE7,
// in argument order). The trailing arguments are forwarded to every case.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(                    \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7, ...)        \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                   \
  AT_DISPATCH_CASE(DTYPE1, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE2, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE3, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE4, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE5, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE6, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE7, __VA_ARGS__)
709 
// Full dispatch over TYPE: hands NAME and the case list produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7 (seven extra scalar types)
// to AT_DISPATCH_SWITCH. All trailing arguments are forwarded verbatim.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7(                          \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7,              \
    TYPE, NAME, ...)                                                     \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(                       \
          DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7,        \
          __VA_ARGS__))
733 
// Case list: everything AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX emits,
// followed by one AT_DISPATCH_CASE per extra scalar type (DTYPE1..DTYPE8,
// in argument order). The trailing arguments are forwarded to every case.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8(                    \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7, DTYPE8,     \
    ...)                                                                \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                   \
  AT_DISPATCH_CASE(DTYPE1, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE2, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE3, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE4, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE5, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE6, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE7, __VA_ARGS__)                                 \
  AT_DISPATCH_CASE(DTYPE8, __VA_ARGS__)
753 
// Full dispatch over TYPE: hands NAME and the case list produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8 (eight extra scalar types)
// to AT_DISPATCH_SWITCH. All trailing arguments are forwarded verbatim.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8(                            \
    DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7, DTYPE8,        \
    TYPE, NAME, ...)                                                       \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE,                                                                \
      NAME,                                                                \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8(                         \
          DTYPE1, DTYPE2, DTYPE3, DTYPE4, DTYPE5, DTYPE6, DTYPE7, DTYPE8,  \
          __VA_ARGS__))
779 
// Case list for the Bits* scalar types: one AT_DISPATCH_CASE each for
// Bits1x8, Bits2x4, Bits4x2, Bits8 and Bits16, in that order. Every case
// receives the same forwarded trailing arguments.
#define AT_DISPATCH_CASE_BIT_TYPES(...)                      \
  AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
786 
// Full dispatch over TYPE using the case list from AT_DISPATCH_CASE_BIT_TYPES.
// NAME and the trailing arguments are passed through unchanged.
#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                          \
      TYPE,                                    \
      NAME,                                    \
      AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
789 
// Full dispatch over TYPE restricted to at::ScalarType::Int and
// at::ScalarType::Long. Both cases are sibling entries built with
// AT_PRIVATE_CASE_TYPE_USING_HINT and the hint name `index_t`; presumably
// that is the name under which the forwarded body sees the selected C++
// type (verify against AT_PRIVATE_CASE_TYPE_USING_HINT).
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)                                \
  AT_DISPATCH_SWITCH(                                                           \
      TYPE,                                                                     \
      NAME,                                                                     \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Int, index_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__))
798 
799 // ----------------------------------------------------------------------------
800 // DEPRECATED MACROS, DON'T USE THESE
801 // ----------------------------------------------------------------------------
802 
// DEPRECATED: use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...).
//
// Dispatches on TYPE over the cases produced by
// AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, ...), passing NAME and
// the trailing arguments through to AT_DISPATCH_SWITCH. The call to
// detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() precedes the dispatch;
// NOTE(review): presumably it exists only so that using this macro emits that
// helper's deprecation diagnostic -- confirm against its declaration.
//
// The expansion is ONE parenthesized comma expression rather than two
// statements. With two statements:
//   * `if (c) AT_DISPATCH_ALL_TYPES_AND_HALF(...);` guarded only the helper
//     call and ran the dispatch unconditionally;
//   * `return AT_DISPATCH_ALL_TYPES_AND_HALF(...);` in a void function
//     returned the helper's result and left the dispatch unreachable;
//   * the result could not be used as a value.
// The comma operator still evaluates the helper first. The `(void)` cast
// guarantees the built-in comma operator is used.
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)                 \
  ((void)detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(),           \
   AT_DISPATCH_SWITCH(                                                  \
       TYPE,                                                            \
       NAME,                                                            \
       AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)))
809