#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#define _USE_MATH_DEFINES

#include <math.h>

#include <ATen/core/Tensor.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>


namespace {
// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
// to ensure the CUDA math library's isfinite is actually what gets called in
// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda.
//
// isfinite_ensure_cuda_math is defined outside at::native because:
// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated
//   Tensor at::native::isfinite(const Tensor&), resulting in an error:
//   "no suitable constructor exists to convert from "float" to "at::Tensor""
// - Unfortunately, the CUDA math library documentation doesn't say how (or if) you can provide a full namespace path
//   to ensure that its version of a particular function is invoked.  It only shows bare (not-namespaced)
//   calls to its routines inside kernel or device functions.
// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with CUDA 9 on Windows.
//
// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as the math library docs
// suggest, and allows disambiguated usage in the lambda within the at::native region.
// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that
// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed."
static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float val) {
  return isfinite(val);
}
}

namespace at::native {

namespace {
// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_.
// Handles individual tensors that are acceptable to unscale but not MTA-safe.
void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
                                             Tensor& found_inf,
                                             const Tensor& inv_scale)
{
  // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks.

  // It's not obvious gpu_kernel always guards onto its argument.  Guarding here just in case.
  const OptionalDeviceGuard device_guard(device_of(scaled_grad));

  // Acts on scaled_grad in place.
  auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    iter.dtype(),
    "_amp_non_finite_check_and_unscale_cuda",
    [&iter, &found_inf, &inv_scale] {
      auto* found_inf_ptr = found_inf.mutable_data_ptr<float>();
      auto* inv_scale_ptr = inv_scale.const_data_ptr<float>();

      using opmath_t = at::opmath_type<scalar_t>;

      gpu_kernel(iter,
                 [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t {
                   auto val = static_cast<opmath_t>(val_in);
                   if (!isfinite_ensure_cuda_math(val)) {
                     *found_inf_ptr = 1.f;
                   }
                   // Every thread accesses inv_scale, but it will hit in cache.
                   const auto inv_scale_val = *inv_scale_ptr;
                   return static_cast<scalar_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
                 });
    });
}
} // anonymous namespace


// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0.
// Uses multi-tensor apply (MTA) to process all MTA-safe tensors.
//
// Args:
// scaled_grads:  A TensorList of scaled gradient tensors.  May contain infs or NaNs.
// found_inf:  A single-element float tensor to which 1.0 will be written if any gradient contains infs/NaNs.
//             Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
// inv_scale:  The inverse of the scale factor by which scaled_grads are currently multiplied.
void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
                                                     Tensor& found_inf,
                                                     const Tensor& inv_scale)
{
  if (scaled_grads.size() == 0) {
    return;
  }

  TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");

  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
  check_foreach_api_restrictions(scaled_grads);

  std::vector<std::vector<at::Tensor>> tensor_lists;

  // is_non_overlapping_and_dense() is not available in Python.
  // GradScaler can't filter for it. We need to filter here.
  if (can_use_fast_route(scaled_grads)) {
    // Hopefully common case.
    // can_use_fast_route is true, which confirms:
    //  - all scaled_grads are strided
    //  - all scaled_grads are non overlapping and dense
    //  - all scaled_grads are on the same device
    //  - all scaled_grads are of the same dtype
    TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
    // Sets up MTA launch to use scaled_grads as-is.
    tensor_lists.emplace_back(scaled_grads.vec());
  } else {
    // Hopefully uncommon case.
    // can_use_fast_route is an all-or-nothing check.  In this path it was false,
    // so any of the above confirmations could have gone wrong.
    // We filter MTA-safe tensors into an MTA-able list.
    // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel.
    // If a tensor is unacceptable, we throw an error to blame GradScaler.
    tensor_lists.resize(1);
    tensor_lists[0].reserve(scaled_grads.size());
    auto expected_device = scaled_grads[0].device();
    const auto expected_dtype = scaled_grads[0].scalar_type();
    for (const Tensor& t : scaled_grads) {
      // Ensures GradScaler filtered scaled_grads by device.
      TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
      TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
      TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
      if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) {
        // t is acceptable but not MTA-safe.  Falls back to single-tensor TensorIterator kernel.
        _amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
                                                found_inf,
                                                inv_scale);
      } else {
        tensor_lists[0].push_back(t);
      }
    }
    if (tensor_lists[0].size() == 0) {
      return;
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    tensor_lists[0][0].scalar_type(),
    "_amp_foreach_non_finite_check_and_unscale_cuda",
    [&tensor_lists, &found_inf, &inv_scale] {
      auto* found_inf_ptr = found_inf.mutable_data_ptr<float>();
      auto* inv_scale_ptr = inv_scale.const_data_ptr<float>();

      using opmath_t = at::opmath_type<scalar_t>;

      // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
      multi_tensor_apply<1>(tensor_lists,
                            UnaryOpFunctor<scalar_t,
                                           /* depth */ 1,
                                           /* r_args_depth */ 1,
                                           /* res_arg_index */ 0>(),
                            [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
                              // There is a slight asymmetry here with the TensorIterator kernel above.
                              // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
                              if (!isfinite_ensure_cuda_math(val)) {
                                *found_inf_ptr = 1.f;
                              }
                              // Every thread accesses inv_scale, but it will hit in cache.
                              const auto inv_scale_val = *inv_scale_ptr;
                              return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
                            });
    });
}
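
// Illustrative usage (a minimal sketch, not part of this file's API surface):
// GradScaler normally drives this op from Python during unscale_(); the ATen-level
// call below assumes the generated entry point name
// at::_amp_foreach_non_finite_check_and_unscale_ and simply shows the expected
// tensor shapes and dtypes.
//
//   auto opts      = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
//   auto grad      = at::randn({1024}, opts) * 65536.0;   // a scaled gradient
//   auto found_inf = at::zeros({1}, opts);                 // caller pre-zeroes this
//   auto inv_scale = at::full({1}, 1.0 / 65536.0, opts);
//   at::_amp_foreach_non_finite_check_and_unscale_({grad}, found_inf, inv_scale);
//   // grad is now unscaled in place; found_inf holds 1.0 iff any element was inf/NaN.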


// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval)
{
  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite_ensure_cuda_math(new_scale)) {
          *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}
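
// Worked example of the update rule above (illustrative numbers only, matching the
// defaults GradScaler typically uses: growth_factor 2.0, backoff_factor 0.5,
// growth_interval 2000; treat these values as assumptions, not part of this kernel):
//   - A step that saw an inf/NaN:  scale 65536 -> 32768, growth_tracker -> 0.
//   - 1999 clean steps in a row:   scale unchanged, growth_tracker counts 1..1999.
//   - The 2000th clean step:       growth_tracker + 1 == growth_interval, so the
//                                  scale doubles (32768 -> 65536) and the tracker resets to 0.
//   - If doubling would overflow fp32 to inf, the scale is left unchanged but the
//     tracker still resets.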


// _amp_update_scale_cuda_ asynchronously updates the scale tensor in place.
//
// Args:
// current_scale:  A one-element CUDA float tensor containing the scale value.
// growth_tracker:  A one-element torch.cuda.IntTensor containing the number of recent consecutive unskipped steps.
// found_inf:  A one-element CUDA float tensor. If > 0, indicates that infs/NaNs were found by the relevant
//             prior _amp_foreach_non_finite_check_and_unscale_cuda_ call; 0 indicates no infs/NaNs were found.
// growth_factor:  Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor:  Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval:  Number of consecutive unskipped steps that must occur for current_scale to be multiplied by
//                   growth_factor.
//
// Returns:
// current_scale
Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval)
{
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor.");
  TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor.");
  TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor.");
  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");

  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}
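
// Illustrative usage (a minimal sketch; the ATen entry point name
// at::_amp_update_scale_ and its exact argument order are assumptions --
// GradScaler normally reaches this kernel through its Python binding after
// unscaling and stepping):
//
//   auto f_opts         = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
//   auto i_opts         = at::TensorOptions().device(at::kCUDA).dtype(at::kInt);
//   auto current_scale  = at::full({1}, 65536.0, f_opts);
//   auto growth_tracker = at::zeros({1}, i_opts);
//   auto found_inf      = at::zeros({1}, f_opts);  // written by the unscale op above
//   at::_amp_update_scale_(current_scale, growth_tracker, found_inf,
//                          /*growth_factor=*/2.0, /*backoff_factor=*/0.5,
//                          /*growth_interval=*/2000);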

} // namespace at::native