xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/fused_adamw_impl.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at {
5 namespace native {
6 
// Declarations of the fused AdamW entry points for the CUDA backend (this
// header lives under aten/src/ATen/native/cuda/). Only declarations appear
// here; the definitions are in another translation unit.
//
// NOTE(review): std::optional is used below, but <optional> is not included
// directly. It is presumably reached transitively through
// <ATen/core/Tensor.h>. Consider including it explicitly.
//
// Overload 1: the learning rate is passed as a host-side `double`.
//
// Parameters:
//   params, grads, exp_avgs, exp_avg_sqs, state_steps:
//     Lists of tensors (at::TensorList). NOTE(review): presumably parallel
//     lists of equal length, where element i of each list belongs to the same
//     parameter. Confirm against the definition.
//     NOTE(review): the trailing underscore in the function name follows
//     ATen's in-place naming convention, so params / exp_avgs / exp_avg_sqs /
//     state_steps are presumably updated in place. Confirm.
//   lr, beta1, beta2, weight_decay, eps:
//     Scalar hyperparameters passed by value. The names match those of the
//     AdamW optimizer.
//   maximize:
//     NOTE(review): presumably flips the update direction to maximize the
//     objective, as with torch.optim.AdamW(maximize=True). Confirm.
//   grad_scale:
//     Optional tensor. NOTE(review): presumably the AMP gradient scale used
//     to unscale `grads` before the update. Confirm.
//   found_inf:
//     Optional tensor. NOTE(review): presumably a flag that makes the step a
//     no-op when non-zero. Confirm.
// Returns void; any effect is observable only through the tensor arguments.
7 void _fused_adamw_cuda_impl_(
8     at::TensorList params,
9     at::TensorList grads,
10     at::TensorList exp_avgs,
11     at::TensorList exp_avg_sqs,
12     at::TensorList state_steps,
13     const double lr,
14     const double beta1,
15     const double beta2,
16     const double weight_decay,
17     const double eps,
18     const bool maximize,
19     const std::optional<at::Tensor>& grad_scale,
20     const std::optional<at::Tensor>& found_inf);
21 
// Overload 2: identical to overload 1 except that `lr` is passed as a tensor
// (`const at::Tensor&`) instead of a `double`. All other parameters have the
// same names, types and order.
// NOTE(review): presumably this lets the learning rate live on the device,
// so the caller need not read it back to the host. It also presumably
// expects a single-element tensor. Confirm against the definition.
22 void _fused_adamw_cuda_impl_(
23     at::TensorList params,
24     at::TensorList grads,
25     at::TensorList exp_avgs,
26     at::TensorList exp_avg_sqs,
27     at::TensorList state_steps,
28     const at::Tensor& lr,
29     const double beta1,
30     const double beta2,
31     const double weight_decay,
32     const double eps,
33     const bool maximize,
34     const std::optional<at::Tensor>& grad_scale,
35     const std::optional<at::Tensor>& found_inf);
36 
37 } // namespace native
38 } // namespace at
39