#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>

namespace at::native {

namespace {

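// Element-wise SGD update applied to a register tile of kILP values.
// r_args[0] holds the parameters, r_args[1] the gradients, and, when
// depth == 3, r_args[2] the momentum buffer. Arithmetic is carried out in
// opmath_t (float for half/bfloat16 inputs). If grad_scale_ptr is set, the
// gradient is unscaled and written back; if lr_ptr is set, the learning
// rate is read from device memory instead of the scalar lr argument.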
template <typename scalar_t, int depth>
C10_DEVICE __forceinline__ void sgd_math(
    scalar_t r_args[depth][kILP],
    const double weight_decay,
    const double momentum,
    const float* lr_ptr,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  const double double_lr = lr_ptr != nullptr ? *lr_ptr : lr;
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    auto p = static_cast<opmath_t>(r_args[0][ii]);
    auto g = static_cast<opmath_t>(r_args[1][ii]);
    if (grad_scale_ptr) {
      g /= static_cast<double>(*grad_scale_ptr);
      r_args[1][ii] = g;
    }
    if (maximize) {
      g *= -1.0;
    }
    if (weight_decay != 0) {
      g += weight_decay * p;
    }
    if (depth > 2) {
      const auto momentum_buffer = is_first_step
          ? g
          : (momentum * static_cast<opmath_t>(r_args[2][ii]) +
             (1 - dampening) * g);
      r_args[2][ii] = momentum_buffer;

      if (nesterov) {
        g = g + momentum * momentum_buffer;
      } else {
        g = momentum_buffer;
      }
    }
    p -= double_lr * g;
    r_args[0][ii] = p;
  }
}

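// multi_tensor_apply functor: each block loads its chunk of every tensor
// list into registers, runs sgd_math on the tile, and stores the results
// back. depth == 2 covers SGD without momentum (params, grads); depth == 3
// adds the momentum buffers. When the grad scaler reported non-finite
// gradients (*found_inf_ptr == 1), the whole update is skipped.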
template <typename scalar_t, int depth>
struct FusedSgdMathFunctor {
  static_assert(
      depth == 2 || depth == 3,
      "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
  C10_DEVICE __forceinline__ void operator()(
      const int chunk_size,
      TensorListMetadata<depth>& tl,
      const double weight_decay,
      const double momentum,
      const float* lr_ptr,
      const double lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step,
      const float* grad_scale_ptr,
      const float* found_inf_ptr) {
    if (found_inf_ptr && *found_inf_ptr == 1) {
      return;
    }
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];

    scalar_t* args[depth];
    scalar_t r_args[depth][kILP];
    const auto all_aligned{
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    const auto use_faster_load_store =
        (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned;
    if (use_faster_load_store) {
      for (auto i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (auto i = 0; i < depth; i++) {
          load_store(r_args[i], args[i], 0, i_start);
        }
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        load_store(args[0], r_args[0], i_start, 0);
        if (grad_scale_ptr) {
          load_store(args[1], r_args[1], i_start, 0);
        }
        if (depth > 2) {
          load_store(args[2], r_args[2], i_start, 0);
        }
      }
    } else {
      for (auto i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth>(r_args, args, i_start, chunk_size, n);
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        store_args(args[0], r_args[0], i_start, chunk_size, n);
        if (grad_scale_ptr) {
          store_args(args[1], r_args[1], i_start, chunk_size, n);
        }
        if (depth > 2) {
          store_args(args[2], r_args[2], i_start, chunk_size, n);
        }
      }
    }
  }
};

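// Fused SGD with momentum (momentum != 0) and a scalar learning rate:
// validates the fast-path restrictions and launches the depth-3 functor
// over {params, grads, momentum_buffer_list}.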
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

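// Same as above, but the learning rate is passed as a tensor. A CPU lr
// tensor is unwrapped with item<double>() and forwarded to the scalar
// overload; a GPU lr tensor is handed to the kernel as a float pointer so
// the value is read on-device (the scalar lr argument becomes a dummy 1.0).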
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  if (grad_scale != std::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != std::nullopt) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace

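// CUDA entry point for the fused SGD step with a scalar learning rate.
// A non-empty momentum_buffer_list routes to the momentum kernel above;
// otherwise momentum must be 0 and the depth-2 functor runs over
// {params, grads} only.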
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

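// Tensor-lr variant of the entry point: defers to the momentum kernel when
// momentum buffers are present, falls back to the scalar-lr overload for a
// CPU lr tensor, and otherwise checks that lr, grad_scale, and found_inf
// live on the params' device before reading lr on-device.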
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  if (lr.is_cpu()) {
    _fused_sgd_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  if (grad_scale.has_value()) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace at::native