#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/fused_adam_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adam_impl.cuh>

namespace at::native {

// note(crcrpar): To observe the CI rule of 20 minutes of compile time per
// file, the instantiations are defensively split into _impl files. This is
// only needed for CUDA 11.3, for which compilation took about 20 minutes on
// my workstation and about 28 minutes in CI. As a data point, it took about
// 20 seconds with CUDA 11.7 installed in my environment. See
// https://github.com/pytorch/pytorch/pull/81705 for details.
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  // The AMSGrad variant also maintains max_exp_avg_sqs, so that list is
  // included in the fast-path restriction check and passed to the impl.
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

// The following overload is identical except that it takes the learning rate
// as a Tensor rather than a double.
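// When lr is a CPU scalar tensor its value is read on the host and the call
// is forwarded to the double-lr overload above; when lr is a CUDA tensor it
// is validated against the param device and kept on device for the impl
// (presumably so the learning rate can be read inside the kernel, e.g. for
// CUDA-graph-capturable optimizers).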
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_adam_kernel_cuda_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale != std::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != std::nullopt) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}
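
// A minimal illustrative sketch of calling the Tensor-lr overload directly.
// The tensors here are hypothetical placeholders; in practice this kernel is
// reached from the Python optimizer via the dispatcher (e.g.
// torch.optim.Adam(..., fused=True)), not by calling it by hand:
//
//   std::vector<at::Tensor> params = ..., grads = ..., exp_avgs = ...,
//       exp_avg_sqs = ..., max_exp_avg_sqs = ..., state_steps = ...;
//   // 0-dim CUDA tensor holding the learning rate, on the params' device.
//   at::Tensor lr =
//       at::full({}, 1e-3, params[0].options().dtype(at::kFloat));
//   at::native::_fused_adam_kernel_cuda_(
//       params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
//       lr, /*beta1=*/0.9, /*beta2=*/0.999, /*weight_decay=*/0.0,
//       /*eps=*/1e-8, /*amsgrad=*/false, /*maximize=*/false,
//       /*grad_scale=*/std::nullopt, /*found_inf=*/std::nullopt);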

} // namespace at::native