#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/fused_adam_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adam_impl.cuh>

namespace at::native {

// note(crcrpar): To observe the CI rule of 20 minutes per file to compile,
// the instantiations are defensively split into _impl files. This is only
// needed for CUDA 11.3, with which compilation took about 20 minutes on my
// workstation and 28 minutes in CI. As a data point, it took about 20 seconds
// with CUDA 11.7 installed in my environment. See
// https://github.com/pytorch/pytorch/pull/81705 for details.
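// Fused Adam step with a scalar (double) learning rate. All parameter,
// gradient, and optimizer-state tensors are updated in-place, dispatching to
// either _fused_adam_amsgrad_cuda_impl_ or _fused_adam_cuda_impl_.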
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
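  // When AMSGrad is enabled, max_exp_avg_sqs participates in the update and
  // must satisfy the same fast-path restrictions as the other state tensors.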
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

// The following overload differs only in taking the learning rate as a Tensor.
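// A CPU-resident lr is read out once and forwarded to the double-lr overload
// above; a CUDA-resident lr is validated against the params' device and passed
// through to the fused implementations unchanged.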
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
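  // lr on the CPU: unwrap the scalar value and dispatch to the double-lr
  // overload above.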
  if (lr.is_cpu()) {
    _fused_adam_kernel_cuda_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale != std::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != std::nullopt) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

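  // lr lives on the same device as the params; pass it through as a tensor and
  // dispatch to the AMSGrad or plain Adam fused implementation as above.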
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

} // namespace at::native