1 #pragma once
2
3 #include <ATen/core/DeprecatedTypeProperties.h>
4 #include <c10/macros/Macros.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/Half.h>
7 #include <c10/util/Metaprogramming.h>
8 #include <c10/util/complex.h>
9 #include <c10/util/string_view.h>
10
11 #ifdef __CUDACC__
12 #include <cuda.h> // For CUDA_VERSION
13 #endif
14
15 #ifdef TEMPLATE_SELECTIVE_BUILD
16 #include <ATen/selected_mobile_ops.h>
17 #else
18 namespace at {
19 /**
20 * The method should_include_kernel_dtype() returns true/false
21 * based on whether the switching code for a specific dtype should be
22 * included based on build time constants generated from tracing model
23 * execution. This method will be implemented via code-generation and
24 * included in this file when code-gen is ready.
25 */
should_include_kernel_dtype(const char *,at::ScalarType)26 inline constexpr bool should_include_kernel_dtype(
27 const char* /*kernel_tag_str*/,
28 at::ScalarType /*scalar_type*/
29 ) {
30 return true;
31 }
32 } // namespace at
33 #endif
34
35 /**
36 * In the Facebook internal build (using BUCK), this macro is enabled by
37 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
38 * binary.
39 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
namespace at {
namespace detail {
// Receives one string per dispatch, formatted as "<kernel tag>$<dtype name>"
// (see RECORD_KERNEL_FUNCTION_DTYPE below). Only declared here; the
// definition lives outside this header.
TORCH_API void record_kernel_function_dtype(std::string name);
} // namespace detail
} // namespace at

// Expands to a single call expression with NO trailing semicolon, so call
// sites supply their own `;` exactly as for a function call. A semicolon
// baked into the macro would add an extra empty statement, which breaks
// constructs such as `if (c) RECORD_KERNEL_FUNCTION_DTYPE(n, t); else ...`.
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  at::detail::record_kernel_function_dtype(          \
      std::string(NAME) + "$" + toString(enum_type))
#else
// Recording disabled: expands to nothing (arguments are not evaluated), so
// only the caller's `;` remains.
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
#endif
53
// Reports an error (via AT_ERROR) when `enum_type` was not selected for the
// kernel named by `at_dispatch_name`, a local declared by AT_DISPATCH_SWITCH.
// The condition is evaluated with `if constexpr`, so when the dtype is
// selected the error branch is a discarded statement.
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)         \
  do {                                                      \
    if constexpr (at::should_include_kernel_dtype(          \
                      at_dispatch_name, enum_type)) {       \
      /* dtype is selected for this kernel: nothing to do */ \
    } else {                                                \
      AT_ERROR(                                             \
          "dtype '",                                        \
          toString(enum_type),                              \
          "' not selected for kernel tag ",                 \
          at_dispatch_name);                                \
    }                                                       \
  } while (0)
65
// Generates one `case` label for use inside AT_DISPATCH_SWITCH:
//   1. runs the selective-build check for `enum_type`;
//   2. declares the type alias HINT as the C++ type corresponding to
//      `enum_type` (c10::impl::ScalarTypeToCPPTypeT);
//   3. returns the result of invoking the callable passed as __VA_ARGS__,
//      which typically is a lambda that refers to HINT.
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)           \
  case enum_type: {                                                     \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                        \
    using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
    return __VA_ARGS__();                                               \
  }

// The common case label: inside the callable, the dispatched C++ type is
// named `scalar_t`.
#define AT_DISPATCH_CASE(enum_type, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
75
// Case label for a quantized dtype. Unlike AT_DISPATCH_CASE, the C++ type is
// passed explicitly as `scalar_type`. The callable in __VA_ARGS__ can use:
//   scalar_t        - `scalar_type`
//   underlying_t    - `scalar_t::underlying`
//   SCALAR_TYPE     - the enum value `enum_type`
//   UNDERLYING_TYPE - `toUnderlying(enum_type)`
// The C10_UNUSED markers keep unused-variable warnings quiet for callables
// that do not refer to these names.
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...)            \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    return __VA_ARGS__();                                             \
  }
85
// Like AT_DISPATCH_CASE_QINT (same scalar_t / underlying_t / SCALAR_TYPE /
// UNDERLYING_TYPE names), and additionally exposes to the callable:
//   bit_width  - initialized from `bitwidth`
//   quant_min  - initialized from `qmin`
//   quant_max  - initialized from `qmax`
// Used by AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES, where some types have a
// bit width smaller than a byte.
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                           \
    enum_type, scalar_type, bitwidth, qmin, qmax, ...)                \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    C10_UNUSED int bit_width = bitwidth;                              \
    C10_UNUSED int64_t quant_min = qmin;                              \
    C10_UNUSED int64_t quant_max = qmax;                              \
    return __VA_ARGS__();                                             \
  }
99
100 namespace detail {
101
scalar_type(at::ScalarType s)102 inline at::ScalarType scalar_type(at::ScalarType s) {
103 return s;
104 }
105
106 C10_DEPRECATED_MESSAGE(
107 "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
108 "pass an at::ScalarType instead")
scalar_type(const at::DeprecatedTypeProperties & t)109 inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
110 return t.scalarType();
111 }
112
113 C10_DEPRECATED_MESSAGE(
114 "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
115 "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF()116 inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
117
118 C10_DEPRECATED_MESSAGE(
119 "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
120 "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
121 "instead")
deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX()122 inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
123
124 } // namespace detail
125
126 // The AT_DISPATCH_* family of macros provides the ability to
127 // conveniently generate specializations of a kernel over all of the
128 // dtypes we care about in PyTorch. We call it "dispatch" because
129 // we are "dispatching" to the correct, dtype-specific kernel.
130 //
131 // A standard usage looks like:
132 //
133 // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
134 // // Your code here, with 'scalar_t' now defined to
135 // // be the dtype in question
136 // });
137 //
138 // There are many variations of this macro, so it's important to
139 // understand exactly /which/ dtypes you want to get instantiated, as
140 // well as what the "default" set is.
141 //
// The default set of dtypes that are instantiated (e.g., by
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
// but NOT booleans (bool), reduced-precision floats (Half, BFloat16) or
// complex numbers (c10::complex<float>, c10::complex<double>).
147 // This "cut" is somewhat historical (the default types are the
148 // ones that TH historically supported), but it also reflects the
149 // fact that the non-default types are "poorly" behaved (booleans
150 // are NOT integers mod 2, half precision operations ~essentially
151 // don't exist on CPU, complex numbers are an experimental application).
152 //
153 // Here are the questions you should generally ask to decide which
154 // dispatch you want:
155 //
156 // 1. Is this an integral or floating point specific operation?
157 // (If so, you'll want one of the FLOATING or INTEGRAL macros.)
158 //
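// 2. Should half be supported? (If you're on CPU, the answer is almost
//    definitely no. If you do want support, pass at::ScalarType::Half to
//    one of the _AND macros, e.g.
//    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...), or use
//    AT_DISPATCH_FLOATING_TYPES_AND_HALF. Note that
//    AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated.)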
162 //
163 // Much rarer situations:
164 //
// 3. Should bool be supported? (You often have to write your kernel
//    differently if arithmetic operations are involved.) If so,
//    use AT_DISPATCH_ALL_TYPES_AND along with at::ScalarType::Bool.
168 //
169 // 4. Should complex be supported? The answer is almost always no,
170 // unless you are working on "generic" code that should work on
171 // all dtypes.
172 //
173 // Parameters:
174 // -----------
175 //
176 // 1. The NAME argument is a "tag" that is used to trace and then
177 // conditionally compile fragments of the case statements such
178 // that the kernel functions are specialized only for the dtypes
179 // that are needed. The NAME parameter *must* be a build time
180 // const char* (can't be std::string, etc...)
181 //
//    Please ensure that the NAME is unique for every implementation,
//    or you run the risk of over-including code for the kernel
//    functions. There is no risk of missing out on any code: a shared
//    NAME can only cause unneeded dtype specializations to be kept
//    (costing binary size), never cause a needed one to be dropped.
186 //
187 // Switch-like syntax:
188 // -------------------
189 // There is also a switch-case like syntax which is useful if a kernel
190 // needs to be specialized for particular scalar types
191 //
192 // AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
193 // AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
194 // op_integral<scalar_t>(iter);
195 // })
196 // AT_DISPATCH_CASE_FLOATING_TYPES([&] {
197 // op_floating<scalar_t>(iter);
198 // })
199 // AT_DISPATCH_CASE(kBool, [&] {
200 // op_bool(iter);
201 // })
202 // );
203 //
204 // For each AT_DISPATCH_FOO macro, there is a corresponding
205 // AT_DISPATCH_CASE_FOO macro which can be used inside of an
206 // AT_DISPATCH_SWITCH block.
207
// NB: AT_DISPATCH_SWITCH binds TYPE to a local named `the_type` and derives
// the switched-on dtype from it, so TYPE is evaluated exactly once. The name
// `the_type` is kept for backwards compatibility, in case code inside a
// dispatched lambda refers to it (probably nobody does, but keeping it
// doesn't hurt).
212
// Switches on the dtype of TYPE and returns whatever the matching case
// returns. The cases come from __VA_ARGS__, normally one or more
// AT_DISPATCH_CASE* macros. The switch sits inside an immediately-invoked
// lambda, so the whole macro is a single expression yielding the case result.
//
// Locals visible to the cases (their names are relied upon elsewhere):
//   the_type          - reference bound to TYPE (evaluated once, here)
//   at_dispatch_name  - NAME; must be a compile-time constant (constexpr).
//                       Read by AT_PRIVATE_CHECK_SELECTIVE_BUILD,
//                       RECORD_KERNEL_FUNCTION_DTYPE and the error below.
//   _st               - the at::ScalarType that is switched on
// A dtype with no matching case reaches `default:` and is reported through
// AT_ERROR with the kernel name and the dtype name.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
  [&] {                                                                     \
    const auto& the_type = TYPE;                                            \
    constexpr const char* at_dispatch_name = NAME;                          \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
    switch (_st) {                                                          \
      __VA_ARGS__                                                           \
      default:                                                              \
        AT_ERROR(                                                           \
            '"',                                                            \
            at_dispatch_name,                                               \
            "\" not implemented for '",                                     \
            toString(_st),                                                  \
            "'");                                                           \
    }                                                                       \
  }()
231
// Floating-point dispatch families.
//
// AT_DISPATCH_CASE_<FAMILY>(...) expands to a run of `case` labels for use
// inside AT_DISPATCH_SWITCH; AT_DISPATCH_<FAMILY>(TYPE, NAME, ...) wraps that
// run in its own AT_DISPATCH_SWITCH. The _ANDn case lists are built
// incrementally (_ANDn = _AND(n-1) followed by one more case), so they expand
// to: Double, Float, SCALARTYPE1, ..., SCALARTYPEn.

#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                               \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Double, Float, then Half.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))

// Half and BFloat16 only (no float/double).
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...)  \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND2(       \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                          \
      TYPE,                                    \
      NAME,                                    \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(    \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3(  \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND2(        \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)   \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND3(                    \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND3(                 \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND3(                      \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND4(                                 \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                              \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
309
// Complex and floating-point-plus-complex dispatch families.
//
// AT_DISPATCH_CASE_<FAMILY>(...) expands to `case` labels for use inside
// AT_DISPATCH_SWITCH; AT_DISPATCH_<FAMILY>(TYPE, NAME, ...) wraps them in a
// switch. Each _ANDn case list is the _AND(n-1) list followed by SCALARTYPEn,
// so the floating+complex lists expand to: Double, Float, ComplexDouble,
// ComplexFloat, SCALARTYPE1, ..., SCALARTYPEn.

#define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
  AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                              \
      TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                              \
      TYPE,                                                        \
      NAME,                                                        \
      AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
    SCALARTYPE, TYPE, NAME, ...)                     \
  AT_DISPATCH_SWITCH(                                \
      TYPE,                                          \
      NAME,                                          \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(                   \
    SCALARTYPE1, SCALARTYPE2, ...)                                          \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE1, __VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(   \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)         \
  AT_DISPATCH_SWITCH(                                  \
      TYPE,                                            \
      NAME,                                            \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)           \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)              \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(        \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(     \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(    \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(          \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(                     \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(                  \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5(                 \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(                       \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5(                      \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    TYPE, NAME, ...)                                                      \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5(                   \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6(                 \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    SCALARTYPE6, ...)                                                     \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5(                       \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,    \
      __VA_ARGS__)                                                        \
  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6(                      \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    SCALARTYPE6, TYPE, NAME, ...)                                         \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6(                   \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          SCALARTYPE6, __VA_ARGS__))
456
// Integral, "all types" (integral + floating) and quantized dispatch
// families. AT_DISPATCH_CASE_<FAMILY>(...) expands to `case` labels for use
// inside AT_DISPATCH_SWITCH; AT_DISPATCH_<FAMILY>(TYPE, NAME, ...) wraps them
// in a switch.

// Byte, Char, Int, Long, Short.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                               \
      TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// The default dtype set: the integral cases followed by the floating cases.
#define AT_DISPATCH_CASE_ALL_TYPES(...)        \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))

// Quantized families. Every case here is generated by AT_DISPATCH_CASE_QINT
// or AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE, so the callable sees scalar_t,
// underlying_t, SCALAR_TYPE and UNDERLYING_TYPE.

// QInt8, QUInt8.
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)               \
  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)

#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))

// The byte-sized quantized cases followed by QInt32.
#define AT_DISPATCH_CASE_QINT_TYPES(...)        \
  AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE,                                                     \
      NAME,                                                     \
      AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Quantized types including ones narrower than a byte (bit widths 4 and 2);
// each case also exposes bit_width, quant_min and quant_max to the callable.
// NOTE: expands to CHAR_BIT / SCHAR_MIN / SCHAR_MAX / UCHAR_MAX / INT_MIN /
// INT_MAX, which come from <climits> (not included directly by this header).
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                     \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)       \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt32,                                                        \
      at::qint32,                                                         \
      CHAR_BIT * sizeof(int),                                             \
      INT_MIN,                                                            \
      INT_MAX,                                                            \
      __VA_ARGS__)                                                        \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                 \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)

#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
527
// "All types" families with extra dtypes, plus the bit and index families.
//
// AT_DISPATCH_CASE_<FAMILY>(...) expands to `case` labels for use inside
// AT_DISPATCH_SWITCH; AT_DISPATCH_<FAMILY>(TYPE, NAME, ...) wraps them in a
// switch. Each _ANDn case list is the _AND(n-1) list followed by SCALARTYPEn,
// so the extra dtypes always follow the base set, in argument order.

// ---- ALL_TYPES plus N extra dtypes ----

#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                          \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                         \
      TYPE,                                                                   \
      NAME,                                                                   \
      AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND3(                                \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)                         \
  AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND3(                         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND3(                      \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

// ---- ALL_TYPES plus complex plus N extra dtypes ----

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                      \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE,                                                                \
      NAME,                                                                \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(                    \
    SCALARTYPE1, SCALARTYPE2, ...)                                      \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, __VA_ARGS__)  \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(  \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)   \
  AT_DISPATCH_SWITCH(                            \
      TYPE,                                      \
      NAME,                                      \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)      \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)         \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(          \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(               \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                          \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                       \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                      \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                            \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                           \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    TYPE, NAME, ...)                                                      \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                        \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(                   \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,   \
    SCALARTYPE6, ...)                                                  \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                         \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
      __VA_ARGS__)                                                     \
  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                           \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    SCALARTYPE6, TYPE, NAME, ...)                                         \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(                        \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          SCALARTYPE6, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(                   \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,   \
    SCALARTYPE6, SCALARTYPE7, ...)                                     \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(                         \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
      SCALARTYPE6, __VA_ARGS__)                                        \
  AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7(                           \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    SCALARTYPE6, SCALARTYPE7, TYPE, NAME, ...)                            \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(                        \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          SCALARTYPE6, SCALARTYPE7, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8(                   \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,   \
    SCALARTYPE6, SCALARTYPE7, SCALARTYPE8, ...)                        \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(                         \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
      SCALARTYPE6, SCALARTYPE7, __VA_ARGS__)                           \
  AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8(                           \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,      \
    SCALARTYPE6, SCALARTYPE7, SCALARTYPE8, TYPE, NAME, ...)               \
  AT_DISPATCH_SWITCH(                                                     \
      TYPE,                                                               \
      NAME,                                                               \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8(                        \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, \
          SCALARTYPE6, SCALARTYPE7, SCALARTYPE8, __VA_ARGS__))

// ---- Bit types ----

// Bits1x8, Bits2x4, Bits4x2, Bits8, Bits16.
#define AT_DISPATCH_CASE_BIT_TYPES(...)                  \
  AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)

#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))

// ---- Index types ----

// Dispatches over Int and Long only. Inside the callable, the dispatched C++
// type is named `index_t` (not `scalar_t`).
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)      \
  AT_DISPATCH_SWITCH(                                 \
      TYPE,                                           \
      NAME,                                           \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                \
          at::ScalarType::Int, index_t, __VA_ARGS__)  \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                \
          at::ScalarType::Long, index_t, __VA_ARGS__))
798
799 // ----------------------------------------------------------------------------
800 // DEPRECATED MACROS, DON'T USE THESE
801 // ----------------------------------------------------------------------------
802
// Deprecated: use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...).
//
// Expands to ONE parenthesized expression. The comma operator first calls the
// no-op deprecated marker, which makes the compiler emit the deprecation
// warning at the use site, and then performs the dispatch and yields its
// result. The previous two-statement form `marker(); AT_DISPATCH_SWITCH(...)`
// had several problems:
//   - it silently skipped the dispatch under `return AT_DISPATCH_...(...)`,
//     because the `return` bound to the marker call;
//   - it broke as the body of an unbraced `if`;
//   - it could not be used as an initializer.
// `::detail` is fully qualified, as in AT_DISPATCH_SWITCH, so the marker still
// resolves when the macro is used inside a namespace that has its own
// `detail` namespace.
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
  (::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(), \
   AT_DISPATCH_SWITCH(                                  \
       TYPE,                                            \
       NAME,                                            \
       AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)))
809