#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/fused_adamw_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adamw_impl.cuh>

namespace at {
namespace native {

// note(crcrpar): To observe the CI rule of 20 minutes per file to compile, the
// instantiations are defensively split into _impl files. This only matters for
// CUDA 11.3, with which compilation took about 20 minutes on my workstation and
// 28 minutes in CI; as a data point, it took about 20 seconds with CUDA 11.7
// installed in my environment. See
// https://github.com/pytorch/pytorch/pull/81705 for details.
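// Scalar-lr entry point: checks that the tensor lists taking part in the fused
// update share dtype, device, and layout, then dispatches to either the AMSGrad
// or the plain fused AdamW implementation.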
void _fused_adamw_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
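  // With amsgrad, the update also maintains the elementwise running maximum of
  // the second moments, so max_exp_avg_sqs takes part in both the fast-path
  // check and the implementation call.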
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

// The following overload simply has a Tensor lr
void _fused_adamw_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
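  // A CPU lr tensor is read once on the host and forwarded to the scalar-lr
  // overload above; a CUDA lr tensor is passed through to the fused
  // implementations, and its device is validated against the params below.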
  if (lr.is_cpu()) {
    _fused_adamw_kernel_cuda_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale != std::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != std::nullopt) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

} // namespace native
} // namespace at