xref: /aosp_15_r20/external/pytorch/aten/src/ATen/autocast_mode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/Operators.h>
6 #include <torch/library.h>
7 
8 #include <c10/core/impl/LocalDispatchKeySet.h>
9 #include <c10/util/intrusive_ptr.h>
10 
11 namespace at::autocast {
12 
13 TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
14 TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
15 TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
16 TORCH_API void set_autocast_dtype(
17     at::DeviceType device_type,
18     at::ScalarType dtype);
19 TORCH_API void clear_cache();
20 TORCH_API int increment_nesting();
21 TORCH_API int decrement_nesting();
22 TORCH_API bool is_autocast_cache_enabled();
23 TORCH_API void set_autocast_cache_enabled(bool enabled);
24 
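// Example (illustrative only, not part of this header): the device-generic
// accessors above are the preferred way to query or configure autocast state
// from C++. The function name below is hypothetical; the calls are exactly the
// declarations in this file, assuming a CUDA-enabled build.
//
//   void configure_autocast_for_cuda() {
//     at::autocast::set_autocast_enabled(at::kCUDA, true);
//     at::autocast::set_autocast_dtype(at::kCUDA, at::kHalf);
//     if (at::autocast::is_autocast_enabled(at::kCUDA)) {
//       auto dtype = at::autocast::get_autocast_dtype(at::kCUDA); // at::kHalf
//       (void)dtype;
//     }
//   }
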
25 // deprecated CUDA-specific autocast APIs
26 C10_DEPRECATED_MESSAGE(
27     "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
28 TORCH_API inline bool is_enabled() {
29   TORCH_WARN_DEPRECATION(
30       "at::autocast::",
31       __func__,
32       "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
33   return is_autocast_enabled(at::kCUDA);
34 }
35 C10_DEPRECATED_MESSAGE(
36     "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
37 TORCH_API inline void set_enabled(bool enabled) {
38   TORCH_WARN_DEPRECATION(
39       "at::autocast::",
40       __func__,
41       "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
42   set_autocast_enabled(at::kCUDA, enabled);
43 }
44 C10_DEPRECATED_MESSAGE(
45     "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
46 TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
47   TORCH_WARN_DEPRECATION(
48       "at::autocast::",
49       __func__,
50       "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
51   return get_autocast_dtype(at::kCUDA);
52 }
53 C10_DEPRECATED_MESSAGE(
54     "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
55 TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
56   TORCH_WARN_DEPRECATION(
57       "at::autocast::",
58       __func__,
59       "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
60   set_autocast_dtype(at::kCUDA, dtype);
61 }
62 
63 #define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type)                                          \
64   C10_DEPRECATED_MESSAGE(                                                                            \
65       "at::autocast::is_" #name                                                                      \
66       "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type         \
67       ") instead.")                                                                                  \
68   TORCH_API inline bool is_##name##_enabled() {                                                      \
69     TORCH_WARN_DEPRECATION(                                                                          \
70         "at::autocast::",                                                                            \
71         __func__,                                                                                    \
72         "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type               \
73         ") instead.")                                                                                \
74     return is_autocast_enabled(device_type);                                                         \
75   }                                                                                                  \
76                                                                                                      \
77   C10_DEPRECATED_MESSAGE(                                                                            \
78       "at::autocast::set_" #name                                                                     \
79       "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
80       ", enabled) instead.")                                                                         \
81   TORCH_API inline void set_##name##_enabled(bool enabled) {                                         \
82     TORCH_WARN_DEPRECATION(                                                                          \
83         "at::autocast::",                                                                            \
84         __func__,                                                                                    \
85         "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type       \
86         ", enabled) instead.")                                                                       \
87     set_autocast_enabled(device_type, enabled);                                                      \
88   }                                                                                                  \
89                                                                                                      \
90   C10_DEPRECATED_MESSAGE(                                                                            \
91       "at::autocast::get_autocast_" #name                                                            \
92       "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type            \
93       ") instead.")                                                                                  \
94   TORCH_API inline at::ScalarType get_autocast_##name##_dtype() {                                    \
95     TORCH_WARN_DEPRECATION(                                                                          \
96         "at::autocast::",                                                                            \
97         __func__,                                                                                    \
98         "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type                \
99         ") instead.")                                                                                \
100     return get_autocast_dtype(device_type);                                                          \
101   }                                                                                                  \
102                                                                                                      \
103   C10_DEPRECATED_MESSAGE(                                                                            \
104       "at::autocast::set_autocast_" #name                                                            \
105       "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type       \
106       ", dtype) instead.")                                                                           \
107   TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) {                          \
108     TORCH_WARN_DEPRECATION(                                                                          \
109         "at::autocast::",                                                                            \
110         __func__,                                                                                    \
111         "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type           \
112         ", dtype) instead.")                                                                         \
113     set_autocast_dtype(device_type, dtype);                                                          \
114   }
115 
116 #define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \
117   _(cpu, at::kCPU)                                \
118   _(xpu, at::kXPU)                                \
119   _(xla, at::kXLA)                                \
120   _(hpu, at::kHPU)                                \
121   _(ipu, at::kIPU)                                \
122   _(privateuseone, at::kPrivateUse1)
123 
124 // deprecated autocast APIs for other backends
125 AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
126 
127 namespace {
128 inline bool is_autocast_eligible(
129     const Tensor& tensor,
130     c10::DeviceType device_type) {
131   switch (device_type) {
132     case c10::DeviceType::CUDA:
133       return (tensor.is_cuda() || tensor.is_xla()) &&
134           tensor.is_floating_point();
135     case c10::DeviceType::CPU:
136       return (tensor.is_cpu() || tensor.is_mkldnn()) &&
137           tensor.is_floating_point();
138     case c10::DeviceType::XPU:
139       return tensor.is_xpu() && tensor.is_floating_point();
140     case c10::DeviceType::IPU:
141       return tensor.is_ipu() && tensor.is_floating_point();
142     case c10::DeviceType::HPU:
143       return tensor.is_hpu() && tensor.is_floating_point();
144     case c10::DeviceType::XLA:
145       return tensor.is_xla() && tensor.is_floating_point();
146     case c10::DeviceType::PrivateUse1:
147       return tensor.is_privateuseone() && tensor.is_floating_point();
148     case c10::DeviceType::MPS:
149       return tensor.is_mps() && tensor.is_floating_point();
150     default:
151       return false;
152   }
153 }
154 } // namespace
155 
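// Example (illustrative): is_autocast_eligible is the device/dtype gate used
// by the helpers below. A float32 or float16 CUDA tensor is eligible under
// c10::DeviceType::CUDA, while an integer tensor, or a CPU tensor checked
// against CUDA, is not; the separate fp64 exclusion happens in is_eligible()
// and prioritize() further down.
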
156 inline DispatchKey get_autocast_dispatch_key_from_device_type(
157     c10::DeviceType device_type) {
158   switch (device_type) {
159     case c10::DeviceType::CUDA:
160       return DispatchKey::Autocast;
161     case c10::DeviceType::CPU:
162       return DispatchKey::AutocastCPU;
163     case c10::DeviceType::XPU:
164       return DispatchKey::AutocastXPU;
165     case c10::DeviceType::IPU:
166       return DispatchKey::AutocastIPU;
167     case c10::DeviceType::HPU:
168       return DispatchKey::AutocastHPU;
169     case c10::DeviceType::XLA:
170       return DispatchKey::AutocastXLA;
171     case c10::DeviceType::PrivateUse1:
172       return DispatchKey::AutocastPrivateUse1;
173     case c10::DeviceType::MPS:
174       return DispatchKey::AutocastMPS;
175     default:
176       throw std::runtime_error(
177           "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
178   }
179 }
180 
181 inline bool is_autocast_available(c10::DeviceType device_type) {
182   if (device_type == at::kCPU || device_type == at::kCUDA ||
183       device_type == at::kXPU || device_type == at::kIPU ||
184       device_type == at::kHPU || device_type == at::kXLA ||
185       device_type == at::kPrivateUse1 || device_type == at::kMPS) {
186     return true;
187   } else {
188     return false;
189   }
190 }
191 
192 inline at::ScalarType get_lower_precision_fp_from_device_type(
193     c10::DeviceType device_type) {
194   if (is_autocast_available(device_type)) {
195     return get_autocast_dtype(device_type);
196   } else {
197     throw std::runtime_error(
198         "unknown device type for autocast in get_lower_precision_fp_from_device_type");
199   }
200 }
201 
202 /********************************************************************
203 Logic to extract the promote type from any Tensor or TensorList args.
204 ********************************************************************/
205 
206 // Overload to catch Tensor args.
207 // If nextArg is floating-point, compare its scalar_type with our
208 // current best guess for the promote type, and update if necessary.
209 inline at::ScalarType prioritize(
210     at::ScalarType current,
211     const Tensor& nextArg,
212     c10::DeviceType device_type = c10::DeviceType::CUDA) {
213   if (current == at::kDouble) {
214     AT_ERROR("promote type is double in at::autocast::prioritize");
215     return current;
216   }
217   at::ScalarType lower_precision_fp =
218       get_lower_precision_fp_from_device_type(device_type);
219   if (is_autocast_eligible(nextArg, device_type)) {
220     auto next = nextArg.scalar_type();
221     if (next == at::kDouble) {
222       return current; // ignores double tensors
223     } else if (current == at::kFloat || next == at::kFloat) {
224       return at::kFloat; // prioritizes float over lower_precision_fp
225     } else if (current == lower_precision_fp && next == lower_precision_fp) {
226       return lower_precision_fp;
227     } else {
228       AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
229       return current;
230     }
231   } else {
232     return current;
233   }
234 }
235 
236 // Overload to catch TensorList args (e.g. cat, stack).
237 // Reuses the overload above to process each Tensor in the list.
238 inline at::ScalarType prioritize(
239     at::ScalarType current,
240     const TensorList& list,
241     c10::DeviceType device_type = c10::DeviceType::CUDA) {
242   for (const auto& tensor : list) {
243     current = prioritize(current, tensor, device_type);
244   }
245   return current;
246 }
247 
248 inline at::ScalarType prioritize(
249     at::ScalarType current,
250     const ITensorListRef& list,
251     c10::DeviceType device_type = c10::DeviceType::CUDA) {
252   for (const auto& tensor : list) {
253     current = prioritize(current, tensor, device_type);
254   }
255   return current;
256 }
257 
258 // Template to catch non-Tensor args (no-op that returns current best guess)
259 template <typename T>
260 inline at::ScalarType prioritize(
261     at::ScalarType current,
262     T nextArg,
263     c10::DeviceType device_type = c10::DeviceType::CUDA) {
264   return current;
265 }
266 
267 // Overload for the tail case.
268 inline at::ScalarType promote_type(
269     at::ScalarType current,
270     c10::DeviceType device_type) {
271   return current;
272 }
273 
274 // Unpack args and determine if incoming lower_precision_fp tensors need to be
275 // promoted to float32. Non-Tensor arguments are ignored.
276 template <typename Arg0, typename... Args>
277 inline at::ScalarType promote_type(
278     at::ScalarType current,
279     c10::DeviceType device_type,
280     Arg0 arg0,
281     Args... args) {
282   auto new_current = prioritize(current, arg0, device_type);
283   return promote_type(new_current, device_type, args...);
284 }
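// Example (illustrative): with CUDA's default lower_precision_fp (fp16),
// promote_type folds prioritize() over the tensor args, so mixing fp16 and
// fp32 inputs promotes to fp32, while an all-fp16 argument list stays fp16.
// The tensors below are hypothetical:
//
//   at::Tensor a = at::randn({2, 2}, at::kCUDA).to(at::kHalf);
//   at::Tensor b = at::randn({2, 2}, at::kCUDA); // fp32
//   auto widest = at::autocast::promote_type(
//       at::kHalf, c10::DeviceType::CUDA, a, b); // == at::kFloat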
285 
286 /****************************************************
287 Logic to apply cached casting to any Tensor argument.
288 ****************************************************/
289 inline bool is_eligible(
290     const Tensor& arg,
291     c10::DeviceType device_type = c10::DeviceType::CUDA) {
292   return (
293       arg.defined() && is_autocast_eligible(arg, device_type) &&
294       (arg.scalar_type() != at::kDouble));
295 }
296 
297 // Overload to catch Tensor args
298 TORCH_API Tensor cached_cast(
299     at::ScalarType to_type,
300     const Tensor& arg,
301     c10::DeviceType device_type = c10::DeviceType::CUDA);
302 
303 // Overload to process std::optional<Tensor>
304 inline std::optional<Tensor> cached_cast(
305     at::ScalarType to_type,
306     const std::optional<Tensor>& arg,
307     c10::DeviceType device_type = c10::DeviceType::CUDA) {
308   if (arg.has_value()) {
309     return cached_cast(to_type, *arg, device_type);
310   } else {
311     return std::nullopt;
312   }
313 }
314 
315 // Overload to process TensorLists
316 inline std::vector<Tensor> cached_cast(
317     at::ScalarType to_type,
318     const TensorList& arg,
319     c10::DeviceType device_type = c10::DeviceType::CUDA) {
320   std::vector<Tensor> vec;
321   vec.reserve(arg.size());
322   for (const auto& t : arg) {
323     vec.emplace_back(cached_cast(to_type, t, device_type));
324   }
325   return vec;
326 }
327 
328 inline std::vector<Tensor> cached_cast(
329     at::ScalarType to_type,
330     const ITensorListRef& arg,
331     c10::DeviceType device_type = c10::DeviceType::CUDA) {
332   std::vector<Tensor> vec;
333   vec.reserve(arg.size());
334   for (const auto& t : arg) {
335     vec.emplace_back(cached_cast(to_type, t, device_type));
336   }
337   return vec;
338 }
339 
340 // Template to catch non-Tensor args.
341 template <typename T>
342 inline T cached_cast(
343     at::ScalarType to_type,
344     T arg,
345     c10::DeviceType device_type = c10::DeviceType::CUDA) {
346   return arg;
347 }
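// Example (illustrative): inside a lower_precision_fp wrapper on CUDA,
//
//   auto w16 = at::autocast::cached_cast(at::kHalf, weight, c10::DeviceType::CUDA);
//
// returns an fp16 copy of a hypothetical fp32 `weight`; the copy may be cached
// and reused across ops in the same autocast-enabled region (e.g. for leaf
// tensors that require grad) when caching is enabled, see
// set_autocast_cache_enabled above. Ineligible args (integral dtypes, fp64, or
// tensors on a different device) pass through unchanged.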
348 
349 /*******************************************************
350 Logic to flip an output dtype flag.
351 Keep it simple for now by assuming only one such flag is
352 present in the argument list.  If I ever need a function
353 with more than one flag I'll figure out something else.
354 The policy is:
355 If the user has explicitly specified a dtype, respect it.
356 Otherwise, set it to the autocast type.
357 ********************************************************/
358 
359 // Overload to catch dtype flags
360 std::optional<ScalarType> inline set_opt_dtype(
361     at::ScalarType to_type,
362     const std::optional<ScalarType>& dtype) {
363   return dtype.has_value() ? dtype : to_type;
364 }
365 
366 // Template to catch other args
367 template <typename T>
368 inline T set_opt_dtype(at::ScalarType to_type, T arg) {
369   return arg;
370 }
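// Example (illustrative): for a fp32_set_opt_dtype op such as softmax,
//
//   set_opt_dtype(at::kFloat, std::optional<at::ScalarType>{});            // optional holding at::kFloat
//   set_opt_dtype(at::kFloat, std::optional<at::ScalarType>(at::kDouble)); // optional holding at::kDouble
//
// i.e. a user-specified dtype wins, otherwise the autocast choice (fp32) is
// applied; non-dtype arguments fall through the generic overload untouched.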
371 
372 template <typename... Args>
373 inline bool firstarg_is_eligible(
374     c10::DeviceType device_type,
375     const Tensor& arg,
376     Args... args) {
377   return is_eligible(arg, device_type);
378 }
379 
380 template <typename... Args>
381 inline at::ScalarType type_from_firstarg(
382     c10::DeviceType device_type,
383     at::ScalarType to_type,
384     const Tensor& arg,
385     Args... args) {
386   return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
387 }
388 
389 // Policies correspond to op categories that need code-divergent handling.
390 // Wrapper templates below are specialized based on a policy template parameter.
391 enum class CastPolicy : uint8_t {
392   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
393                           // running the op. Currently, lower_precision_fp is
394                           // fp16 for AutocastCUDA, and is defined by user
395                           // (default bf16) for AutocastCPU or other device.
396   fp32, // Cast all inputs to at::kFloat before running the op.
397   fp32_set_opt_dtype, // Treats functions (like softmax) that
398                       //  1. we'd like to run in fp32 and
399                       //  2. have a std::optional<ScalarType> arg that controls
400                       //  the output type.
401                       // fp32_set_opt_dtype wrappers' policy is: if the output
402                       // type is already set, don't touch it, otherwise, set
403                       // it to at::kFloat.
404   fp32_append_dtype, // Treats functions (like norm) that
405                      //  1. we'd like to run in fp32 and
406                      //  2. have some overloads that accept an output type and
407                      //  other overloads that don't.
408                      // fp32_append_dtype wrappers wrap the overloads that don't
409                      // have an output dtype.
410                      // The wrapper policy is:  append at::kFloat to the args,
411                      // and redispatch to the type-aware overload.
412   promote, // Run in the widest dtype among several args.
413 };
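// Representative ops for each policy, taken from the op lists at the end of
// this header: mm and conv2d use lower_precision_fp; acos and nll_loss use
// fp32; softmax and cumsum use fp32_set_opt_dtype; norm.Scalar uses
// fp32_append_dtype; addcdiv and atan2 use promote.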
414 
415 /********************************************************************************************************
416 Templates to provide wrapper functions
417 
418 I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
419 extract args and return type. (see also
420 https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
421 
422 This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
423 of (in my case several specializations of) an interior "WrapFunction_".
424 Interior WrapFunction_ specializations are defined for each CastPolicy.
425 ********************************************************************************************************/
426 
427 // Base template for WrapFunction_, which is specialized to contain a "call"
428 // method for each CastPolicy
429 template <
430     CastPolicy policy,
431     c10::DeviceType device_type,
432     class Redispatch,
433     Redispatch* F,
434     class Ret,
435     class ArgList>
436 struct WrapFunction_ {};
437 
438 // CastPolicy::lower_precision_fp General_DeviceType
439 template <
440     c10::DeviceType device_type,
441     class Redispatch,
442     Redispatch* F,
443     class Ret,
444     class... Args>
445 struct WrapFunction_<
446     CastPolicy::lower_precision_fp,
447     device_type,
448     Redispatch,
449     F,
450     Ret,
451     guts::typelist::typelist<Args...>> {
452   static Ret call(Args... args) {
453     c10::impl::ExcludeDispatchKeyGuard no_autocast(
454         get_autocast_dispatch_key_from_device_type(device_type));
455     return (*F)(cached_cast(
456         get_lower_precision_fp_from_device_type(device_type),
457         args,
458         device_type)...);
459   }
460 };
461 
462 // CastPolicy::fp32 General_DeviceType
463 template <
464     c10::DeviceType device_type,
465     class Redispatch,
466     Redispatch* F,
467     class Ret,
468     class... Args>
469 struct WrapFunction_<
470     CastPolicy::fp32,
471     device_type,
472     Redispatch,
473     F,
474     Ret,
475     guts::typelist::typelist<Args...>> {
476   static Ret call(Args... args) {
477     c10::impl::ExcludeDispatchKeyGuard no_autocast(
478         get_autocast_dispatch_key_from_device_type(device_type));
479     return (*F)(cached_cast(at::kFloat, args, device_type)...);
480   }
481 };
482 
483 // CastPolicy::fp32_set_opt_dtype General_DeviceType
484 template <
485     c10::DeviceType device_type,
486     class Redispatch,
487     Redispatch* F,
488     class Ret,
489     class... Args>
490 struct WrapFunction_<
491     CastPolicy::fp32_set_opt_dtype,
492     device_type,
493     Redispatch,
494     F,
495     Ret,
496     guts::typelist::typelist<Args...>> {
497   static Ret call(Args... args) {
498     c10::impl::ExcludeDispatchKeyGuard no_autocast(
499         get_autocast_dispatch_key_from_device_type(device_type));
500     if (firstarg_is_eligible(device_type, args...)) {
501       return (*F)(set_opt_dtype(at::kFloat, args)...);
502     } else {
503       // If ineligible, calls F with unaltered args.  Does not set opt dtype,
504       // because setting opt dtype explicitly may interfere with internal
505       // implicit promotion decisions.
506       return (*F)(args...);
507     }
508   }
509 };
510 
511 // CastPolicy::fp32_append_dtype General_DeviceType
512 template <
513     c10::DeviceType device_type,
514     class Redispatch,
515     Redispatch* F,
516     class Ret,
517     class... Args>
518 struct WrapFunction_<
519     CastPolicy::fp32_append_dtype,
520     device_type,
521     Redispatch,
522     F,
523     Ret,
524     guts::typelist::typelist<Args...>> {
525   static Ret call(Args... args) {
526     c10::impl::ExcludeDispatchKeyGuard no_autocast(
527         get_autocast_dispatch_key_from_device_type(device_type));
528     at::ScalarType out_type =
529         type_from_firstarg(device_type, at::kFloat, args...);
530     return (*F)(args..., out_type);
531   }
532 };
533 
534 // CastPolicy::promote General_DeviceType
535 template <
536     c10::DeviceType device_type,
537     class Redispatch,
538     Redispatch* F,
539     class Ret,
540     class... Args>
541 struct WrapFunction_<
542     CastPolicy::promote,
543     device_type,
544     Redispatch,
545     F,
546     Ret,
547     guts::typelist::typelist<Args...>> {
548   static Ret call(Args... args) {
549     c10::impl::ExcludeDispatchKeyGuard no_autocast(
550         get_autocast_dispatch_key_from_device_type(device_type));
551     auto to_type = promote_type(
552         get_lower_precision_fp_from_device_type(device_type),
553         device_type,
554         args...);
555     return (*F)(cached_cast(to_type, args, device_type)...);
556   }
557 };
558 
559 // Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
560 // core/boxing/impl/WrapFunctionIntoFunctor.h)
561 template <
562     CastPolicy policy,
563     c10::DeviceType device_type,
564     class Registered, // The signature for which we're registering.  The
565                       // dispatcher's calling code invokes our registered
566                       // functions with arguments matching Registered, so we
567                       // register WrapFunction_::call methods with a matching
568                       // signature to properly field those arguments.
569     // guts::function_traits below extracts return_type and
570     // parameter_types from Registered, which WrapFunction_
571     // templates above use to declare their call methods.
572     class Redispatch, // The signature for the function we're redispatching to.
573                       // In most cases this is the same as Registered, but for
574                       // some ops (for example, ops where we append a dtype)
575                       // it's useful to redispatch to a function with a
576                       // different signature.
577     Redispatch* F> // The actual function we're redispatching to.
578 struct WrapFunction final {
579   using type = WrapFunction_<
580       policy,
581       device_type,
582       Redispatch,
583       F,
584       typename guts::function_traits<Registered>::return_type,
585       typename guts::function_traits<Registered>::parameter_types>;
586 };
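// Example (illustrative expansion): the KERNEL macros defined below register a
// wrapper for aten::mm under CastPolicy::lower_precision_fp roughly equal to
//
//   &at::autocast::WrapFunction<
//       at::autocast::CastPolicy::lower_precision_fp,
//       c10::DeviceType::CUDA,
//       decltype(ATEN_FN(mm)), // Registered signature
//       decltype(ATEN_FN(mm)), // Redispatch signature (identical here)
//       &ATEN_FN(mm)>::type::call
//
// i.e. a static function with mm's signature that excludes the autocast key,
// casts eligible Tensor args to the CUDA lower-precision dtype via
// cached_cast, and redispatches to at::mm.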
587 
588 /*****************************************************************************************************************
589 This section performs load-time registration for autocast wrappers.
590 
591 It's debatable at what level operations should be patched.  We'd like casts to
592 be autograd-exposed and precede autograd history recording, so that for
593 lower_precision_fp ops, input tensors are saved for backward in
594 lower_precision_fp rather than fp32.  Saving inputs in lower_precision_fp
595 can significantly reduce a model's memory footprint.
596 
597 Option 1 (strawman):  Patch only at the level of explicit calls into
598 cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
599 guaranteed to use Tensor Cores, therefore they're the ones that will benefit
600 most from lower_precision_fp.   Potential pitfall:  convolutions (and other ops)
601 are wrapped in several layers of at::* calls.  If one of those happens to record
602 autograd history, then we've lost the opportunity to save inputs in
603 lower_precision_fp.
604 
605 Option 2:  Patch the Python-exposed surface of calls, to make 100% sure autograd
606 history recording can't sneak in ahead of autocast.  This mirrors Apex most
607 closely.
608 
609 I think Option 2 is the right answer for all ops, not just convolutions. Option
610 2 is what I implement here.
611 *****************************************************************************************************************/
612 
613 /********************************************************************************************************************
614 Explicit registration for out-of-place ops
615 
616 The stuff below could be codegenned.  Ed said
617 > you are going to have to write the function definition at some point, I
618 wouldn't try to get clever about it. Therefore, for the moment, this is all
619 copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
620 ********************************************************************************************************************/
621 
622 } // namespace at::autocast
623 
624 #define ADD_NS(RAW_OP) at::RAW_OP
625 
626 #define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N
627 #define _KERNEL_OVERLOAD_NARG(...) \
628   C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))
629 
630 // Common cases where registration signature matches redispatch signature
631 // (that's why SIGNATURE is repeated in the WrapFunction instantiation)
632 #define KERNEL1(DISPATCHKEY, OP, POLICY)      \
633   m.impl(                                     \
634       TORCH_SELECTIVE_NAME("aten::" #OP),     \
635       &::at::autocast::WrapFunction<          \
636           ::at::autocast::CastPolicy::POLICY, \
637           DISPATCHKEY,                        \
638           decltype(ATEN_FN(OP)),              \
639           decltype(ATEN_FN(OP)),              \
640           &ATEN_FN(OP)>::type::call);
641 
642 #define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY)      \
643   m.impl(                                               \
644       TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
645       &::at::autocast::WrapFunction<                    \
646           ::at::autocast::CastPolicy::POLICY,           \
647           DISPATCHKEY,                                  \
648           decltype(ATEN_FN2(OP, OVERLOAD)),             \
649           decltype(ATEN_FN2(OP, OVERLOAD)),             \
650           &ATEN_FN2(OP, OVERLOAD)>::type::call);
651 
652 #define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \
653   C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__)
654 
655 #define _KERNEL_IMPL(DISPATCHKEY, ...) \
656   _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)
657 
658 // KERNEL dispatches to KERNEL1 or KERNEL2 based on the number of arguments passed.
659 #define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__)
660 
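// Example (illustrative sketch of how KERNEL is typically consumed in a
// registration translation unit such as autocast_mode.cpp; TORCH_LIBRARY_IMPL
// comes from torch/library.h):
//
//   TORCH_LIBRARY_IMPL(aten, Autocast, m) { // CUDA autocast dispatch key
//     KERNEL(c10::DeviceType::CUDA, mm, lower_precision_fp)
//     KERNEL(c10::DeviceType::CUDA, pow, Tensor_Scalar, fp32)
//   }
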
661 // Less-common but still useful case: redispatching to a function
662 // with a new signature (e.g. appending a dtype)
663 #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(      \
664     DISPATCHKEY,                                    \
665     REDISPATCH_FUNC,                                \
666     REGISTER_NAME,                                  \
667     REGISTER_SIGNATURE,                             \
668     REDISPATCH_SIGNATURE,                           \
669     POLICY)                                         \
670   m.impl(                                           \
671       TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
672       &::at::autocast::WrapFunction<                \
673           ::at::autocast::CastPolicy::POLICY,       \
674           DISPATCHKEY,                              \
675           REGISTER_SIGNATURE,                       \
676           REDISPATCH_SIGNATURE,                     \
677           &REDISPATCH_FUNC>::type::call);
678 
679 // KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
680 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU
681 #define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__)
682 
683 #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
684     REDISPATCH_FUNC,                               \
685     REGISTER_NAME,                                 \
686     REGISTER_SIGNATURE,                            \
687     REDISPATCH_SIGNATURE,                          \
688     POLICY)                                        \
689   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
690       c10::DeviceType::CPU,                        \
691       REDISPATCH_FUNC,                             \
692       REGISTER_NAME,                               \
693       REGISTER_SIGNATURE,                          \
694       REDISPATCH_SIGNATURE,                        \
695       POLICY)
696 
697 // KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
698 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA
699 #define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)
700 
701 #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
702     REDISPATCH_FUNC,                                \
703     REGISTER_NAME,                                  \
704     REGISTER_SIGNATURE,                             \
705     REDISPATCH_SIGNATURE,                           \
706     POLICY)                                         \
707   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
708       c10::DeviceType::CUDA,                        \
709       REDISPATCH_FUNC,                              \
710       REGISTER_NAME,                                \
711       REGISTER_SIGNATURE,                           \
712       REDISPATCH_SIGNATURE,                         \
713       POLICY)
714 
715 // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
716 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
717 #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
718 
719 #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \
720     REDISPATCH_FUNC,                               \
721     REGISTER_NAME,                                 \
722     REGISTER_SIGNATURE,                            \
723     REDISPATCH_SIGNATURE,                          \
724     POLICY)                                        \
725   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
726       c10::DeviceType::XPU,                        \
727       REDISPATCH_FUNC,                             \
728       REGISTER_NAME,                               \
729       REGISTER_SIGNATURE,                          \
730       REDISPATCH_SIGNATURE,                        \
731       POLICY)
732 
733 // KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
734 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
735 #define KERNEL_PRIVATEUSEONE(...) \
736   KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)
737 
738 #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
739     REDISPATCH_FUNC,                                         \
740     REGISTER_NAME,                                           \
741     REGISTER_SIGNATURE,                                      \
742     REDISPATCH_SIGNATURE,                                    \
743     POLICY)                                                  \
744   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                     \
745       c10::DeviceType::PrivateUse1,                          \
746       REDISPATCH_FUNC,                                       \
747       REGISTER_NAME,                                         \
748       REGISTER_SIGNATURE,                                    \
749       REDISPATCH_SIGNATURE,                                  \
750       POLICY)
751 
752 // KERNEL_MPS registration for AutocastMPS
753 #define KERNEL_MPS(OP, POLICY)            \
754   m.impl(                                 \
755       TORCH_SELECTIVE_NAME("aten::" #OP), \
756       &WrapFunction<                      \
757           CastPolicy::POLICY,             \
758           DeviceType::MPS,                \
759           decltype(ATEN_FN(OP)),          \
760           decltype(ATEN_FN(OP)),          \
761           &ATEN_FN(OP)>::type::call);
762 
763 #define KERNEL_MPS2(OP, OVERLOAD, POLICY)               \
764   m.impl(                                               \
765       TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
766       &WrapFunction<                                    \
767           CastPolicy::POLICY,                           \
768           DeviceType::MPS,                              \
769           decltype(ATEN_FN2(OP, OVERLOAD)),             \
770           decltype(ATEN_FN2(OP, OVERLOAD)),             \
771           &ATEN_FN2(OP, OVERLOAD)>::type::call);
772 
773 // Op lists for different policies, kept here so that other backends can
774 // reuse them (see the example sketch at the end of this header).
775 #define AT_FORALL_LOWER_PRECISION_FP(_)  \
776   _(_convolution, deprecated)            \
777   _(_convolution)                        \
778   _(conv1d)                              \
779   _(conv2d)                              \
780   _(conv3d)                              \
781   _(conv_tbc)                            \
782   _(conv_transpose1d)                    \
783   _(conv_transpose2d, input)             \
784   _(conv_transpose3d, input)             \
785   _(convolution)                         \
786   _(prelu)                               \
787   _(addmm)                               \
788   _(addmv)                               \
789   _(addr)                                \
790   _(matmul)                              \
791   _(einsum)                              \
792   _(mm)                                  \
793   _(mv)                                  \
794   _(linalg_vecdot)                       \
795   _(linear)                              \
796   _(addbmm)                              \
797   _(baddbmm)                             \
798   _(bmm)                                 \
799   _(chain_matmul)                        \
800   _(linalg_multi_dot)                    \
801   _(_thnn_fused_lstm_cell)               \
802   _(_thnn_fused_gru_cell)                \
803   _(lstm_cell)                           \
804   _(gru_cell)                            \
805   _(rnn_tanh_cell)                       \
806   _(rnn_relu_cell)                       \
807   _(_scaled_dot_product_flash_attention) \
808   _(scaled_dot_product_attention)
809 
810 #define AT_FORALL_FP32(_)             \
811   _(acos)                             \
812   _(asin)                             \
813   _(cosh)                             \
814   _(erfinv)                           \
815   _(exp)                              \
816   _(expm1)                            \
817   _(log)                              \
818   _(log10)                            \
819   _(log2)                             \
820   _(log1p)                            \
821   _(reciprocal)                       \
822   _(rsqrt)                            \
823   _(sinh)                             \
824   _(tan)                              \
825   _(pow, Tensor_Scalar)               \
826   _(pow, Tensor_Tensor)               \
827   _(pow, Scalar)                      \
828   _(softplus)                         \
829   _(layer_norm)                       \
830   _(native_layer_norm)                \
831   _(group_norm)                       \
832   _(frobenius_norm, dim)              \
833   _(nuclear_norm)                     \
834   _(nuclear_norm, dim)                \
835   _(cosine_similarity)                \
836   _(poisson_nll_loss)                 \
837   _(cosine_embedding_loss)            \
838   _(nll_loss)                         \
839   _(nll_loss2d)                       \
840   _(hinge_embedding_loss)             \
841   _(kl_div)                           \
842   _(l1_loss)                          \
843   _(smooth_l1_loss)                   \
844   _(huber_loss)                       \
845   _(mse_loss)                         \
846   _(margin_ranking_loss)              \
847   _(multilabel_margin_loss)           \
848   _(soft_margin_loss)                 \
849   _(triplet_margin_loss)              \
850   _(multi_margin_loss)                \
851   _(binary_cross_entropy_with_logits) \
852   _(dist)                             \
853   _(pdist)                            \
854   _(cdist)                            \
855   _(renorm)                           \
856   _(logsumexp)                        \
857   _(upsample_nearest1d)               \
858   _(_upsample_nearest_exact1d)        \
859   _(upsample_nearest2d)               \
860   _(_upsample_nearest_exact2d)        \
861   _(upsample_nearest3d)               \
862   _(_upsample_nearest_exact3d)        \
863   _(upsample_linear1d)                \
864   _(upsample_bilinear2d)              \
865   _(_upsample_bilinear2d_aa)          \
866   _(upsample_trilinear3d)             \
867   _(upsample_bicubic2d)               \
868   _(_upsample_bicubic2d_aa)
869 
870 #define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
871   _(prod)                               \
872   _(prod, dim_int)                      \
873   _(prod, dim_Dimname)                  \
874   _(softmax, int)                       \
875   _(softmax, Dimname)                   \
876   _(log_softmax, int)                   \
877   _(log_softmax, Dimname)               \
878   _(cumprod)                            \
879   _(cumprod, dimname)                   \
880   _(cumsum)                             \
881   _(cumsum, dimname)                    \
882   _(linalg_vector_norm)                 \
883   _(linalg_matrix_norm)                 \
884   _(linalg_matrix_norm, str_ord)        \
885   _(sum)                                \
886   _(sum, dim_IntList)                   \
887   _(sum, dim_DimnameList)
888 
889 #define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_)                         \
890   _(ADD_NS(norm),                                                           \
891     "norm.Scalar",                                                          \
892     Tensor(const Tensor&, const Scalar&),                                   \
893     Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType),        \
894     fp32_append_dtype)                                                      \
895   _(ADD_NS(norm),                                                           \
896     "norm.ScalarOpt_dim",                                                   \
897     Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool), \
898     Tensor(                                                                 \
899         const Tensor&,                                                      \
900         const std::optional<Scalar>&,                                       \
901         IntArrayRef,                                                        \
902         bool,                                                               \
903         ScalarType),                                                        \
904     fp32_append_dtype)                                                      \
905   _(ADD_NS(norm),                                                           \
906     "norm.names_ScalarOpt_dim",                                             \
907     Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool), \
908     Tensor(                                                                 \
909         const Tensor&,                                                      \
910         const std::optional<Scalar>&,                                       \
911         DimnameList,                                                        \
912         bool,                                                               \
913         ScalarType),                                                        \
914     fp32_append_dtype)
915 
916 #define AT_FORALL_PROMOTE(_) \
917   _(addcdiv)                 \
918   _(addcmul)                 \
919   _(atan2)                   \
920   _(bilinear)                \
921   _(cross)                   \
922   _(dot)                     \
923   _(vdot)                    \
924   _(grid_sampler)            \
925   _(index_put)               \
926   _(tensordot)               \
927   _(scatter_add)
928
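// Example (illustrative): a backend can stamp out registrations from the op
// lists above by passing in a policy-appending helper; the helper macro name
// here is hypothetical:
//
//   #define _HYPOTHETICAL_CUDA_LOWER_PRECISION_FP(...) KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)
//
//   TORCH_LIBRARY_IMPL(aten, Autocast, m) {
//     AT_FORALL_LOWER_PRECISION_FP(_HYPOTHETICAL_CUDA_LOWER_PRECISION_FP)
//   }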