1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/TypeDefault.h>
3 #include <ATen/native/ForeachUtils.h>
4 #include <c10/util/Exception.h>
5 #include <ATen/native/cuda/fused_adamw_amsgrad_impl.cuh>
6 #include <ATen/native/cuda/fused_adamw_impl.cuh>
7
8 namespace at {
9 namespace native {
10
11 // note(crcrpar): To observe the CI rules, i.e. 20 minutes per file to compile,
12 // defensively split instantiations into _impl files. this is only for CUDA 11.3
13 // for which it took about 20 minutes and 28 minutes in my workstation and CI,
14 // respectively. As a data point, it took about 20 seconds for CUDA 11.7
15 // installed in my environment. See
16 // https://github.com/pytorch/pytorch/pull/81705 for details.
// Fused AdamW step (scalar lr overload).
//
// Validates that every per-parameter tensor list satisfies the fast-path
// restrictions (same dtype, device, and layout across lists — see
// check_fast_path_restrictions), then dispatches to the amsgrad or plain
// CUDA implementation accordingly.
void _fused_adamw_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  // Guard clause: handle the amsgrad variant first and return early, so the
  // non-amsgrad path below reads without extra nesting.
  if (amsgrad) {
    // The amsgrad variant additionally carries max_exp_avg_sqs, so it is
    // included in the fast-path validation.
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK(
      at::native::check_fast_path_restrictions(
          {params, grads, exp_avgs, exp_avg_sqs}),
      "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
  _fused_adamw_cuda_impl_(
      params,
      grads,
      exp_avgs,
      exp_avg_sqs,
      state_steps,
      lr,
      beta1,
      beta2,
      weight_decay,
      eps,
      maximize,
      grad_scale,
      found_inf);
}
74
75 // The following overload simply has a Tensor lr
// Fused AdamW step (Tensor lr overload).
//
// If lr lives on the CPU, its scalar value is read once on the host and the
// work is forwarded to the double-lr overload. Otherwise lr stays on device:
// this overload verifies that lr (and grad_scale / found_inf when present)
// share the params' device, checks the fast-path restrictions, and dispatches
// to the amsgrad or plain CUDA implementation.
void _fused_adamw_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  // Fast path: a CPU scalar lr is materialized on the host and routed to the
  // scalar overload, avoiding any device-resident lr handling.
  if (lr.is_cpu()) {
    _fused_adamw_kernel_cuda_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale.has_value()) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

  // Guard clause: amsgrad variant first, then the plain path without nesting.
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK(
      at::native::check_fast_path_restrictions(
          {params, grads, exp_avgs, exp_avg_sqs}),
      "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
  _fused_adamw_cuda_impl_(
      params,
      grads,
      exp_avgs,
      exp_avg_sqs,
      state_steps,
      lr,
      beta1,
      beta2,
      weight_decay,
      eps,
      maximize,
      grad_scale,
      found_inf);
}
170
171 } // namespace native
172 } // namespace at
173