#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#define _USE_MATH_DEFINES

#include <math.h>

#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAException.h>

namespace {
// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
// to ensure the CUDA math library's isfinite is what actually gets called in
// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda.
//
// isfinite_ensure_cuda_math is defined outside at::native because:
// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated
//   Tensor at::native::isfinite(const Tensor&), resulting in an error:
//   "no suitable constructor exists to convert from "float" to "at::Tensor""
// - Unfortunately, the CUDA math library documentation doesn't say how (or if) you can provide a full namespace path
//   to ensure that its version of a particular function is invoked.  It only shows bare (not-namespaced)
//   calls to its routines inside kernel or device functions.
// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with CUDA 9 on Windows.
//
// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as the math library docs
// suggest, and allows disambiguated usage in the lambda within the at::native region.
// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so the __host__ keyword is needed here, or nvcc complains that
// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed."
static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float val) {
  return isfinite(val);
}
} // anonymous namespace

namespace at::native {

namespace {
// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_.
// Handles individual tensors that are acceptable to unscale but not MTA-safe.
void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
                                             Tensor& found_inf,
                                             const Tensor& inv_scale)
{
  // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks.

  // It's not obvious gpu_kernel always guards onto its argument.  Guarding here just in case.
  const OptionalDeviceGuard device_guard(device_of(scaled_grad));

  // Acts on scaled_grad in place.
  auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    iter.dtype(),
    "_amp_non_finite_check_and_unscale_cuda",
    [&iter, &found_inf, &inv_scale] {
      auto* found_inf_ptr = found_inf.mutable_data_ptr<float>();
      auto* inv_scale_ptr = inv_scale.const_data_ptr<float>();

      using opmath_t = at::opmath_type<scalar_t>;

      gpu_kernel(iter,
                 [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t {
                   auto val = static_cast<opmath_t>(val_in);
                   if (!isfinite_ensure_cuda_math(val)) {
                     *found_inf_ptr = 1.f;
                   }
                   // Every thread accesses inv_scale, but it will hit in cache.
                   const auto inv_scale_val = *inv_scale_ptr;
                   return static_cast<scalar_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
                 });
    });
}
} // anonymous namespace


// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0.
// Uses multi-tensor apply (MTA) to process all MTA-safe tensors.
//
// Args:
// scaled_grads:  A TensorList of scaled gradient tensors.  May contain infs or NaNs.
// found_inf:  A single-element float tensor to which 1.0 will be written if any gradient contains infs/NaNs.
//             Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
// inv_scale:  The inverse of the scale factor by which scaled_grads are currently multiplied.
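//
// For illustration only (not part of this file's logic): on the Python side, GradScaler's
// _unscale_grads_ reaches this kernel through the private op
//   torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
// with grads already grouped by device and dtype.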
void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
                                                     Tensor& found_inf,
                                                     const Tensor& inv_scale)
{
  if (scaled_grads.size() == 0) {
    return;
  }

  TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");

  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
  check_foreach_api_restrictions(scaled_grads);

  std::vector<std::vector<at::Tensor>> tensor_lists;

  // is_non_overlapping_and_dense() is not available in Python, so GradScaler
  // can't filter for it.  We need to filter here.
  if (can_use_fast_route(scaled_grads)) {
    // Hopefully the common case.
    // can_use_fast_route is true, which confirms:
    //  - all scaled_grads are strided
    //  - all scaled_grads are non-overlapping and dense
    //  - all scaled_grads are on the same device
    //  - all scaled_grads are of the same dtype
    TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
    // Sets up MTA launch to use scaled_grads as-is.
    tensor_lists.emplace_back(scaled_grads.vec());
  } else {
    // Hopefully the uncommon case.
    // can_use_fast_route is an all-or-nothing check.  In this path it was false,
    // so any of the above confirmations could have gone wrong.
    // We filter MTA-safe tensors into an MTA-able list.
    // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel.
    // If a tensor is unacceptable, we throw an error to blame GradScaler.
    tensor_lists.resize(1);
    tensor_lists[0].reserve(scaled_grads.size());
    auto expected_device = scaled_grads[0].device();
    const auto expected_dtype = scaled_grads[0].scalar_type();
    for (const Tensor& t : scaled_grads) {
      // Ensures GradScaler filtered scaled_grads by device.
      TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
      TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
      TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
      if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) {
        // t is acceptable but not MTA-safe.  Falls back to the single-tensor TensorIterator kernel.
        _amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
                                                found_inf,
                                                inv_scale);
      } else {
        tensor_lists[0].push_back(t);
      }
    }
    if (tensor_lists[0].size() == 0) {
      return;
    }
  }

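  // At this point tensor_lists[0] holds only MTA-safe tensors: non-overlapping and dense,
  // on one device, sharing a single dtype.  Everything else was either unscaled by the
  // single-tensor fallback above or rejected with an error.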
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    tensor_lists[0][0].scalar_type(),
    "_amp_foreach_non_finite_check_and_unscale_cuda",
    [&tensor_lists, &found_inf, &inv_scale] {
      auto* found_inf_ptr = found_inf.mutable_data_ptr<float>();
      auto* inv_scale_ptr = inv_scale.const_data_ptr<float>();

      using opmath_t = at::opmath_type<scalar_t>;

      // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
      multi_tensor_apply<1>(tensor_lists,
                            UnaryOpFunctor<scalar_t,
                                           /* depth */ 1,
                                           /* r_args_depth */ 1,
                                           /* res_arg_index */ 0>(),
                            [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
                              // There is a slight asymmetry here with the TensorIterator kernel above.
                              // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
                              if (!isfinite_ensure_cuda_math(val)) {
                                *found_inf_ptr = 1.f;
                              }
                              // Every thread accesses inv_scale, but it will hit in cache.
                              const auto inv_scale_val = *inv_scale_ptr;
                              return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
                            });
    });
}


// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
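//
// Illustrative arithmetic only (these numbers are GradScaler's defaults, not something this
// kernel depends on): with growth_factor=2.0, backoff_factor=0.5, growth_interval=2000,
// a step that saw an inf/NaN halves the scale (e.g. 65536.0 -> 32768.0) and resets
// growth_tracker to 0, while 2000 consecutive clean steps double it (32768.0 -> 65536.0).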
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval)
{
  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite_ensure_cuda_math(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}


// _amp_update_scale_cuda_ asynchronously updates the scale tensor in place.
//
// Args:
// current_scale:  A one-element CUDA float tensor containing the scale value.
// growth_tracker:  A one-element torch.cuda.IntTensor containing the number of recent consecutive unskipped steps.
// found_inf:  A one-element CUDA float tensor.  If > 0, indicates that infs/NaNs were found by the relevant
//             prior _amp_non_finite_check_and_unscale_cuda call; 0 means no infs/NaNs were found.
// growth_factor:  Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor:  Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval:  Number of consecutive unskipped steps that must occur for current_scale to be multiplied by
//                   growth_factor.
//
// Returns:
// current_scale
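//
// For illustration only: on the Python side, GradScaler.update() reaches this kernel through
// the private op
//   torch._amp_update_scale_(scale, growth_tracker, found_inf_combined,
//                            growth_factor, backoff_factor, growth_interval)
// where found_inf_combined is the sum of the per-device found_inf tensors.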
Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval)
{
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor.");
  TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor.");
  TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor.");
  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");

  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

} // namespace at::native