#pragma once

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <torch/library.h>

#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>

namespace at::autocast {

TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
TORCH_API void set_autocast_dtype(
    at::DeviceType device_type,
    at::ScalarType dtype);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
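
// Illustrative usage of the device-generic autocast state API (a minimal
// sketch, not part of this header): query and update the autocast settings
// for a given device type before dispatching work.
//
//   if (at::autocast::is_autocast_available(at::kCPU) &&
//       !at::autocast::is_autocast_enabled(at::kCPU)) {
//     at::autocast::set_autocast_enabled(at::kCPU, true);
//     at::autocast::set_autocast_dtype(at::kCPU, at::kBFloat16);
//   }
//   // ... run autocast-eligible ops ...
//   at::autocast::clear_cache(); // drop cached fp32 -> lower-precision casts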

// deprecated CUDA-specific autocast APIs
C10_DEPRECATED_MESSAGE(
    "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
TORCH_API inline bool is_enabled() {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  return is_autocast_enabled(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
TORCH_API inline void set_enabled(bool enabled) {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  set_autocast_enabled(at::kCUDA, enabled);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  return get_autocast_dtype(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  set_autocast_dtype(at::kCUDA, dtype);
}
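
// Migration sketch for the deprecated CUDA-specific entry points (illustrative
// only): each deprecated call maps one-to-one onto the device-generic API.
//
//   // before (deprecated)                // after
//   at::autocast::is_enabled();           // at::autocast::is_autocast_enabled(at::kCUDA);
//   at::autocast::set_enabled(true);      // at::autocast::set_autocast_enabled(at::kCUDA, true);
//   at::autocast::get_autocast_gpu_dtype();            // at::autocast::get_autocast_dtype(at::kCUDA);
//   at::autocast::set_autocast_gpu_dtype(at::kHalf);   // at::autocast::set_autocast_dtype(at::kCUDA, at::kHalf);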

#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type)                  \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::is_" #name                                              \
      "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
      ") instead.")                                                          \
  TORCH_API inline bool is_##name##_enabled() {                              \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
        ") instead.")                                                        \
    return is_autocast_enabled(device_type);                                 \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::set_" #name                                             \
      "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
      ", enabled) instead.")                                                 \
  TORCH_API inline void set_##name##_enabled(bool enabled) {                 \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
        ", enabled) instead.")                                               \
    set_autocast_enabled(device_type, enabled);                              \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::get_autocast_" #name                                    \
      "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
      ") instead.")                                                          \
  TORCH_API inline at::ScalarType get_autocast_##name##_dtype() {            \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
        ") instead.")                                                        \
    return get_autocast_dtype(device_type);                                  \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::set_autocast_" #name                                    \
      "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
      ", dtype) instead.")                                                   \
  TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) {  \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
        ", dtype) instead.")                                                 \
    set_autocast_dtype(device_type, dtype);                                  \
  }

#define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \
  _(cpu, at::kCPU)                                \
  _(xpu, at::kXPU)                                \
  _(xla, at::kXLA)                                \
  _(hpu, at::kHPU)                                \
  _(ipu, at::kIPU)                                \
  _(privateuseone, at::kPrivateUse1)

// deprecated backend-specific autocast APIs for the other backends
AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)

namespace {
inline bool is_autocast_eligible(
    const Tensor& tensor,
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA:
      return (tensor.is_cuda() || tensor.is_xla()) &&
          tensor.is_floating_point();
    case c10::DeviceType::CPU:
      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
          tensor.is_floating_point();
    case c10::DeviceType::XPU:
      return tensor.is_xpu() && tensor.is_floating_point();
    case c10::DeviceType::IPU:
      return tensor.is_ipu() && tensor.is_floating_point();
    case c10::DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    case c10::DeviceType::XLA:
      return tensor.is_xla() && tensor.is_floating_point();
    case c10::DeviceType::PrivateUse1:
      return tensor.is_privateuseone() && tensor.is_floating_point();
    case c10::DeviceType::MPS:
      return tensor.is_mps() && tensor.is_floating_point();
    default:
      return false;
  }
}
} // namespace

inline DispatchKey get_autocast_dispatch_key_from_device_type(
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA:
      return DispatchKey::Autocast;
    case c10::DeviceType::CPU:
      return DispatchKey::AutocastCPU;
    case c10::DeviceType::XPU:
      return DispatchKey::AutocastXPU;
    case c10::DeviceType::IPU:
      return DispatchKey::AutocastIPU;
    case c10::DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    case c10::DeviceType::XLA:
      return DispatchKey::AutocastXLA;
    case c10::DeviceType::PrivateUse1:
      return DispatchKey::AutocastPrivateUse1;
    case c10::DeviceType::MPS:
      return DispatchKey::AutocastMPS;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}

inline bool is_autocast_available(c10::DeviceType device_type) {
  if (device_type == at::kCPU || device_type == at::kCUDA ||
      device_type == at::kXPU || device_type == at::kIPU ||
      device_type == at::kHPU || device_type == at::kXLA ||
      device_type == at::kPrivateUse1 || device_type == at::kMPS) {
    return true;
  } else {
    return false;
  }
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
    c10::DeviceType device_type) {
  if (is_autocast_available(device_type)) {
    return get_autocast_dtype(device_type);
  } else {
    throw std::runtime_error(
        "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (current == at::kDouble) {
    AT_ERROR("promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}

// Overload to catch TensorList args (e.g., for cat and stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
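
// Illustrative behavior of promote_type (a sketch, not part of the API):
// starting the fold from the device's lower-precision dtype, any eligible
// fp32 argument widens the result to fp32, while double tensors and
// non-Tensor arguments are ignored.
//
//   // Hypothetical tensors: a is fp16 on CUDA, b is fp32 on CUDA.
//   auto to_type = at::autocast::promote_type(
//       at::kHalf, c10::DeviceType::CUDA, a, b, /*alpha=*/1.0);
//   // to_type == at::kFloat, because b is an eligible fp32 tensor.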

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA);

// Overload to process std::optional<Tensor>
inline std::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const std::optional<Tensor>& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return std::nullopt;
  }
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return arg;
}
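
// Illustrative use of cached_cast (a sketch): an eligible fp32 tensor is cast
// to the requested dtype (and, when the autocast cache is enabled, the cast of
// an fp32 tensor such as a weight may be cached so repeated uses only pay for
// the cast once per autocast region); ineligible arguments pass through
// unchanged.
//
//   // Hypothetical fp32 CUDA tensor `weight` and int64 tensor `index`:
//   auto w_fp16 = at::autocast::cached_cast(at::kHalf, weight);  // casts
//   auto idx    = at::autocast::cached_cast(at::kHalf, index);   // no-op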

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/

// Overload to catch dtype flags
std::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const std::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}
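
// Illustrative behavior of set_opt_dtype (a sketch): an explicit user dtype
// wins; an unset std::optional is filled with the autocast-chosen type, and
// non-dtype arguments fall through to the generic template unchanged.
//
//   std::optional<at::ScalarType> unset;                      // no user dtype
//   auto a = at::autocast::set_opt_dtype(at::kFloat, unset);  // -> at::kFloat
//   auto b = at::autocast::set_opt_dtype(
//       at::kFloat, std::optional<at::ScalarType>(at::kDouble));
//   // b keeps at::kDouble; the user's explicit choice is respected.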

template <typename... Args>
inline bool firstarg_is_eligible(
    c10::DeviceType device_type,
    const Tensor& arg,
    Args... args) {
  return is_eligible(arg, device_type);
}

template <typename... Args>
inline at::ScalarType type_from_firstarg(
    c10::DeviceType device_type,
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
}

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
                          // running the op. Currently, lower_precision_fp is
                          // fp16 for AutocastCUDA, and is defined by the user
                          // (default bf16) for AutocastCPU and other devices.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      //  1. we'd like to run in fp32 and
                      //  2. have a std::optional<ScalarType> arg that controls
                      //     the output type.
                      // fp32_set_opt_dtype wrappers' policy is: if the output
                      // type is already set, don't touch it, otherwise, set
                      // it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     //  1. we'd like to run in fp32 and
                     //  2. have some overloads that accept an output type and
                     //     other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that don't
                     // have an output dtype.
                     // The wrapper policy is: append at::kFloat to the args,
                     // and redispatch to the type-aware overload.
  promote, // Run in the widest dtype among several args.
};
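
// Illustrative mapping from policy to behavior (a sketch; the authoritative
// per-op assignments live in the registration lists further below):
//
//   at::mm(a_fp16, b_fp32)     // lower_precision_fp: both args cast to fp16
//   at::asin(x_fp16)           // fp32: input cast up to fp32 for accuracy
//   at::softmax(x_fp16, dim)   // fp32_set_opt_dtype: runs in fp32 unless the
//                              //   caller passed an explicit dtype
//   at::norm(x_fp16, p)        // fp32_append_dtype: redispatches to the
//                              //   dtype-taking overload with at::kFloat
//   at::addcmul(a_fp16, b, c)  // promote: runs in the widest input dtype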

/********************************************************************************************************
Templates to provide wrapper functions

I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
extract args and return type. (see also
https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)

This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
of (in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/

// Base template for WrapFunction_, which is specialized to contain a "call"
// method for each CastPolicy
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class ArgList>
struct WrapFunction_ {};

// CastPolicy::lower_precision_fp General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::lower_precision_fp,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(
        get_lower_precision_fp_from_device_type(device_type),
        args,
        device_type)...);
  }
};

// CastPolicy::fp32 General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};

// CastPolicy::fp32_set_opt_dtype General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_set_opt_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    if (firstarg_is_eligible(device_type, args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args. Does not set opt dtype,
      // because setting opt dtype explicitly may interfere with internal
      // implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_append_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    at::ScalarType out_type =
        type_from_firstarg(device_type, at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::promote,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    auto to_type = promote_type(
        get_lower_precision_fp_from_device_type(device_type),
        device_type,
        args...);
    return (*F)(cached_cast(to_type, args, device_type)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
// core/boxing/impl/WrapFunctionIntoFunctor.h)
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Registered, // The signature for which we're registering. The
                      // dispatcher's calling code invokes our registered
                      // functions with arguments matching Registered, so we
                      // register WrapFunction_::call methods with a matching
                      // signature to properly field those arguments.
                      // guts::function_traits below extracts return_type and
                      // parameter_types from Registered, which WrapFunction_
                      // templates above use to declare their call methods.
    class Redispatch, // The signature for the function we're redispatching to.
                      // In most cases this is the same as Registered, but for
                      // some ops (for example, ops where we append a dtype)
                      // it's useful to redispatch to a function with a
                      // different signature.
    Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<
      policy,
      device_type,
      Redispatch,
      F,
      typename guts::function_traits<Registered>::return_type,
      typename guts::function_traits<Registered>::parameter_types>;
};
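
// Illustrative instantiation (a sketch; the KERNEL* macros below generate
// this shape for real registrations): wrap a mm-like function for CUDA under
// the lower_precision_fp policy and register the resulting call method.
//
//   using MmWrapper = at::autocast::WrapFunction<
//       at::autocast::CastPolicy::lower_precision_fp,
//       c10::DeviceType::CUDA,
//       Tensor(const Tensor&, const Tensor&), // Registered signature
//       Tensor(const Tensor&, const Tensor&), // Redispatch signature
//       &at::mm>;
//   // MmWrapper::type::call casts both inputs with cached_cast and then
//   // redispatches to at::mm with autocast excluded from the dispatch set.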

/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.

It's debatable at what level operations should be patched. We'd like casts to
be autograd-exposed and precede autograd history recording, so that for
lower_precision_fp ops, input tensors are saved for backward in
lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
can significantly reduce a model's memory footprint.

Option 1 (strawman): Patch only at the level of explicit calls into
cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
guaranteed to use Tensor Cores, therefore they're the ones that will benefit
most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
are wrapped in several layers of at::* calls. If one of those happens to record
autograd history, then we've lost the opportunity to save inputs in
lower_precision_fp.

Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
history recording can't sneak in ahead of autocast. This mirrors Apex most
closely.

I think Option 2 is the right answer for all ops, not just convolutions. Option
2 is what I implement here.
*****************************************************************************************************************/

/********************************************************************************************************************
Explicit registration for out-of-place ops

The stuff below could be codegenned. Ed said
> you are going to have to write the function definition at some point, I
> wouldn't try to get clever about it
Therefore, for the moment, this is all copy-pasted in from
VariableTypeEverything.cpp with appropriate substitutions.
********************************************************************************************************************/

} // namespace at::autocast

#define ADD_NS(RAW_OP) at::RAW_OP

#define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N
#define _KERNEL_OVERLOAD_NARG(...) \
  C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL1(DISPATCHKEY, OP, POLICY)      \
  m.impl(                                     \
      TORCH_SELECTIVE_NAME("aten::" #OP),     \
      &::at::autocast::WrapFunction<          \
          ::at::autocast::CastPolicy::POLICY, \
          DISPATCHKEY,                        \
          decltype(ATEN_FN(OP)),              \
          decltype(ATEN_FN(OP)),              \
          &ATEN_FN(OP)>::type::call);

#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY)      \
  m.impl(                                               \
      TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
      &::at::autocast::WrapFunction<                    \
          ::at::autocast::CastPolicy::POLICY,           \
          DISPATCHKEY,                                  \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          &ATEN_FN2(OP, OVERLOAD)>::type::call);

#define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \
  C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__)

#define _KERNEL_IMPL(DISPATCHKEY, ...) \
  _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)

// KERNEL dispatches to KERNEL1 or KERNEL2 based on the number of arguments:
// KERNEL(DISPATCHKEY, OP, POLICY) or KERNEL(DISPATCHKEY, OP, OVERLOAD, POLICY).
#define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__)
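
// Illustrative registrations (a sketch of how these macros are typically used
// from an autocast registration TU, e.g. inside
// TORCH_LIBRARY_IMPL(aten, Autocast, m)):
//
//   KERNEL(c10::DeviceType::CUDA, mm, lower_precision_fp)    // -> KERNEL1
//   KERNEL(c10::DeviceType::CUDA, pow, Tensor_Scalar, fp32)  // -> KERNEL2
//
// Both expand to an m.impl(...) call that registers the matching
// WrapFunction<...>::type::call under the op's selective name.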

// Less-common but still useful case: redispatching to a function
// with a new signature (e.g. appending a dtype)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(        \
    DISPATCHKEY,                                      \
    REDISPATCH_FUNC,                                  \
    REGISTER_NAME,                                    \
    REGISTER_SIGNATURE,                               \
    REDISPATCH_SIGNATURE,                             \
    POLICY)                                           \
  m.impl(                                             \
      TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME),   \
      &::at::autocast::WrapFunction<                  \
          ::at::autocast::CastPolicy::POLICY,         \
          DISPATCHKEY,                                \
          REGISTER_SIGNATURE,                         \
          REDISPATCH_SIGNATURE,                       \
          &REDISPATCH_FUNC>::type::call);
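
// Illustrative use (a sketch, mirroring the norm entries in
// AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE below): register "norm.Scalar" but
// redispatch to the dtype-taking overload under the fp32_append_dtype policy.
//
//   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(
//       c10::DeviceType::CUDA,
//       ADD_NS(norm),
//       "norm.Scalar",
//       Tensor(const Tensor&, const Scalar&),
//       Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType),
//       fp32_append_dtype)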

// KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU
#define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
    REDISPATCH_FUNC,                               \
    REGISTER_NAME,                                 \
    REGISTER_SIGNATURE,                            \
    REDISPATCH_SIGNATURE,                          \
    POLICY)                                        \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
      c10::DeviceType::CPU,                        \
      REDISPATCH_FUNC,                             \
      REGISTER_NAME,                               \
      REGISTER_SIGNATURE,                          \
      REDISPATCH_SIGNATURE,                        \
      POLICY)

// KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA
#define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
    REDISPATCH_FUNC,                                \
    REGISTER_NAME,                                  \
    REGISTER_SIGNATURE,                             \
    REDISPATCH_SIGNATURE,                           \
    POLICY)                                         \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
      c10::DeviceType::CUDA,                        \
      REDISPATCH_FUNC,                              \
      REGISTER_NAME,                                \
      REGISTER_SIGNATURE,                           \
      REDISPATCH_SIGNATURE,                         \
      POLICY)

// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \
    REDISPATCH_FUNC,                               \
    REGISTER_NAME,                                 \
    REGISTER_SIGNATURE,                            \
    REDISPATCH_SIGNATURE,                          \
    POLICY)                                        \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
      c10::DeviceType::XPU,                        \
      REDISPATCH_FUNC,                             \
      REGISTER_NAME,                               \
      REGISTER_SIGNATURE,                          \
      REDISPATCH_SIGNATURE,                        \
      POLICY)

// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
#define KERNEL_PRIVATEUSEONE(...) \
  KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
    REDISPATCH_FUNC,                                         \
    REGISTER_NAME,                                           \
    REGISTER_SIGNATURE,                                      \
    REDISPATCH_SIGNATURE,                                    \
    POLICY)                                                  \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                     \
      c10::DeviceType::PrivateUse1,                          \
      REDISPATCH_FUNC,                                       \
      REGISTER_NAME,                                         \
      REGISTER_SIGNATURE,                                    \
      REDISPATCH_SIGNATURE,                                  \
      POLICY)

// KERNEL_MPS/KERNEL_MPS2 registration for AutocastMPS
#define KERNEL_MPS(OP, POLICY)            \
  m.impl(                                 \
      TORCH_SELECTIVE_NAME("aten::" #OP), \
      &WrapFunction<                      \
          CastPolicy::POLICY,             \
          DeviceType::MPS,                \
          decltype(ATEN_FN(OP)),          \
          decltype(ATEN_FN(OP)),          \
          &ATEN_FN(OP)>::type::call);

#define KERNEL_MPS2(OP, OVERLOAD, POLICY)               \
  m.impl(                                               \
      TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
      &WrapFunction<                                    \
          CastPolicy::POLICY,                           \
          DeviceType::MPS,                              \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          &ATEN_FN2(OP, OVERLOAD)>::type::call);
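
// Illustrative MPS registrations (a sketch of typical use from an MPS autocast
// registration TU, e.g. inside TORCH_LIBRARY_IMPL(aten, AutocastMPS, m)):
//
//   KERNEL_MPS(mm, lower_precision_fp)
//   KERNEL_MPS2(pow, Tensor_Scalar, fp32)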

// Op lists for different policies.
// These lists are defined once so that other backends can reuse them when
// registering their own autocast kernels (see the sketch after this list).
#define AT_FORALL_LOWER_PRECISION_FP(_)   \
  _(_convolution, deprecated)             \
  _(_convolution)                         \
  _(conv1d)                               \
  _(conv2d)                               \
  _(conv3d)                               \
  _(conv_tbc)                             \
  _(conv_transpose1d)                     \
  _(conv_transpose2d, input)              \
  _(conv_transpose3d, input)              \
  _(convolution)                          \
  _(prelu)                                \
  _(addmm)                                \
  _(addmv)                                \
  _(addr)                                 \
  _(matmul)                               \
  _(einsum)                               \
  _(mm)                                   \
  _(mv)                                   \
  _(linalg_vecdot)                        \
  _(linear)                               \
  _(addbmm)                               \
  _(baddbmm)                              \
  _(bmm)                                  \
  _(chain_matmul)                         \
  _(linalg_multi_dot)                     \
  _(_thnn_fused_lstm_cell)                \
  _(_thnn_fused_gru_cell)                 \
  _(lstm_cell)                            \
  _(gru_cell)                             \
  _(rnn_tanh_cell)                        \
  _(rnn_relu_cell)                        \
  _(_scaled_dot_product_flash_attention)  \
  _(scaled_dot_product_attention)
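
// Illustrative reuse of a policy op list by a backend (a sketch; the helper
// macro name below is hypothetical):
//
//   #define _KERNEL_CPU_LOW_PRECISION_FP(...) \
//     KERNEL_CPU(__VA_ARGS__, lower_precision_fp)
//
//   TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
//     AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CPU_LOW_PRECISION_FP)
//   }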

#define AT_FORALL_FP32(_)             \
  _(acos)                             \
  _(asin)                             \
  _(cosh)                             \
  _(erfinv)                           \
  _(exp)                              \
  _(expm1)                            \
  _(log)                              \
  _(log10)                            \
  _(log2)                             \
  _(log1p)                            \
  _(reciprocal)                       \
  _(rsqrt)                            \
  _(sinh)                             \
  _(tan)                              \
  _(pow, Tensor_Scalar)               \
  _(pow, Tensor_Tensor)               \
  _(pow, Scalar)                      \
  _(softplus)                         \
  _(layer_norm)                       \
  _(native_layer_norm)                \
  _(group_norm)                       \
  _(frobenius_norm, dim)              \
  _(nuclear_norm)                     \
  _(nuclear_norm, dim)                \
  _(cosine_similarity)                \
  _(poisson_nll_loss)                 \
  _(cosine_embedding_loss)            \
  _(nll_loss)                         \
  _(nll_loss2d)                       \
  _(hinge_embedding_loss)             \
  _(kl_div)                           \
  _(l1_loss)                          \
  _(smooth_l1_loss)                   \
  _(huber_loss)                       \
  _(mse_loss)                         \
  _(margin_ranking_loss)              \
  _(multilabel_margin_loss)           \
  _(soft_margin_loss)                 \
  _(triplet_margin_loss)              \
  _(multi_margin_loss)                \
  _(binary_cross_entropy_with_logits) \
  _(dist)                             \
  _(pdist)                            \
  _(cdist)                            \
  _(renorm)                           \
  _(logsumexp)                        \
  _(upsample_nearest1d)               \
  _(_upsample_nearest_exact1d)        \
  _(upsample_nearest2d)               \
  _(_upsample_nearest_exact2d)        \
  _(upsample_nearest3d)               \
  _(_upsample_nearest_exact3d)        \
  _(upsample_linear1d)                \
  _(upsample_bilinear2d)              \
  _(_upsample_bilinear2d_aa)          \
  _(upsample_trilinear3d)             \
  _(upsample_bicubic2d)               \
  _(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
  _(prod)                               \
  _(prod, dim_int)                      \
  _(prod, dim_Dimname)                  \
  _(softmax, int)                       \
  _(softmax, Dimname)                   \
  _(log_softmax, int)                   \
  _(log_softmax, Dimname)               \
  _(cumprod)                            \
  _(cumprod, dimname)                   \
  _(cumsum)                             \
  _(cumsum, dimname)                    \
  _(linalg_vector_norm)                 \
  _(linalg_matrix_norm)                 \
  _(linalg_matrix_norm, str_ord)        \
  _(sum)                                \
  _(sum, dim_IntList)                   \
  _(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_)                         \
  _(ADD_NS(norm),                                                           \
    "norm.Scalar",                                                          \
    Tensor(const Tensor&, const Scalar&),                                   \
    Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType),        \
    fp32_append_dtype)                                                      \
  _(ADD_NS(norm),                                                           \
    "norm.ScalarOpt_dim",                                                   \
    Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool), \
    Tensor(                                                                 \
        const Tensor&,                                                      \
        const std::optional<Scalar>&,                                       \
        IntArrayRef,                                                        \
        bool,                                                               \
        ScalarType),                                                        \
    fp32_append_dtype)                                                      \
  _(ADD_NS(norm),                                                           \
    "norm.names_ScalarOpt_dim",                                             \
    Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool), \
    Tensor(                                                                 \
        const Tensor&,                                                      \
        const std::optional<Scalar>&,                                       \
        DimnameList,                                                        \
        bool,                                                               \
        ScalarType),                                                        \
    fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
  _(addcdiv)                 \
  _(addcmul)                 \
  _(atan2)                   \
  _(bilinear)                \
  _(cross)                   \
  _(dot)                     \
  _(vdot)                    \
  _(grid_sampler)            \
  _(index_put)               \
  _(tensordot)               \
  _(scatter_add)