#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>

namespace at::native {

namespace {

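// Per-element SGD update applied to a register tile of kILP values per list.
// With a momentum buffer present (depth == 3) this computes, per element,
//   buf = is_first_step ? g : momentum * buf + (1 - dampening) * g
//   g   = nesterov ? g + momentum * buf : buf
//   p  -= lr * g
// mirroring the update documented for torch.optim.SGD. The gradient is
// unscaled in place first when an AMP grad scale is supplied.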
template <typename scalar_t, int depth>
C10_DEVICE __forceinline__ void sgd_math(
    scalar_t r_args[depth][kILP],
    const double weight_decay,
    const double momentum,
    const float* lr_ptr,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr) {
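  // Do the math in opmath_t (float for half/bfloat16 inputs), and prefer a
  // device-resident lr over the host-side scalar when one is supplied.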
  using opmath_t = at::opmath_type<scalar_t>;
  const double double_lr = lr_ptr != nullptr ? *lr_ptr : lr;
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    auto p = static_cast<opmath_t>(r_args[0][ii]);
    auto g = static_cast<opmath_t>(r_args[1][ii]);
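    // AMP: unscale the gradient and write the unscaled value back so the
    // grads list itself ends up holding unscaled gradients.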
    if (grad_scale_ptr) {
      g /= static_cast<double>(*grad_scale_ptr);
      r_args[1][ii] = g;
    }
    if (maximize) {
      g *= -1.0;
    }
    if (weight_decay != 0) {
      g += weight_decay * p;
    }
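    // depth == 3 means a momentum buffer rides along in r_args[2].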
    if (depth > 2) {
      const auto momentum_buffer = is_first_step
          ? g
          : (momentum * static_cast<opmath_t>(r_args[2][ii]) +
             (1 - dampening) * g);
      r_args[2][ii] = momentum_buffer;

      if (nesterov) {
        g = g + momentum * momentum_buffer;
      } else {
        g = momentum_buffer;
      }
    }
    p -= double_lr * g;
    r_args[0][ii] = p;
  }
}

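// Functor handed to multi_tensor_apply: each CUDA block processes one chunk
// of one tensor, staging kILP elements per thread per tensor list in
// registers before calling sgd_math.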
template <typename scalar_t, int depth>
struct FusedSgdMathFunctor {
  static_assert(
      depth == 2 || depth == 3,
      "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
  C10_DEVICE __forceinline__ void operator()(
      const int chunk_size,
      TensorListMetadata<depth>& tl,
      const double weight_decay,
      const double momentum,
      const float* lr_ptr,
      const double lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step,
      const float* grad_scale_ptr,
      const float* found_inf_ptr) {
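    // The AMP grad scaler found an inf/nan this iteration; skip the update.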
    if (found_inf_ptr && *found_inf_ptr == 1) {
      return;
    }
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];

    scalar_t* args[depth];
    scalar_t r_args[depth][kILP];
    const auto all_aligned{
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

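    // Fast path: when the chunk is kILP-aligned and kILP-divisible we can use
    // vectorized load_store; otherwise fall back to element-wise access.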
    const auto use_faster_load_store =
        (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned;
    if (use_faster_load_store) {
      for (auto i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (auto i = 0; i < depth; i++) {
          load_store(r_args[i], args[i], 0, i_start);
        }
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        load_store(args[0], r_args[0], i_start, 0);
        if (grad_scale_ptr) {
          load_store(args[1], r_args[1], i_start, 0);
        }
        if (depth > 2) {
          load_store(args[2], r_args[2], i_start, 0);
        }
      }
    } else {
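      // Slow path: load_args/store_args handle misaligned pointers and ragged
      // tails element by element.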
      for (auto i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth>(r_args, args, i_start, chunk_size, n);
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        store_args(args[0], r_args[0], i_start, chunk_size, n);
        if (grad_scale_ptr) {
          store_args(args[1], r_args[1], i_start, chunk_size, n);
        }
        if (depth > 2) {
          store_args(args[2], r_args[2], i_start, chunk_size, n);
        }
      }
    }
  }
};

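// Fused SGD with momentum (depth 3), scalar lr. Requires momentum > 0 and the
// foreach fast-path restrictions on the tensor lists.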
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

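// Tensor-lr overload: a CPU lr falls back to the scalar overload above, while
// a CUDA lr is passed by pointer so every step reads the current value on
// device (the scalar lr argument is then just a 1.0 placeholder).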
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  if (grad_scale.has_value()) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace

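// Public entry point, scalar lr: dispatches to the momentum variant (depth 3)
// when momentum buffers are supplied, otherwise runs the depth-2 update.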
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

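// Public entry point, tensor lr: same dispatch as above, with a CPU lr tensor
// falling back to the scalar entry point.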
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  if (lr.is_cpu()) {
    _fused_sgd_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  if (grad_scale.has_value()) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace at::native