#pragma once
#include <ATen/core/Tensor.h>

#include <optional>

namespace at {
namespace native {

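// Fused CUDA implementation of the AMSGrad variant of Adam, applied in-place
// across the given tensor lists. This overload takes the learning rate as a
// host-side scalar; the optional grad_scale / found_inf tensors support
// gradient (un)scaling and inf/NaN step skipping as used by AMP.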
void _fused_adam_amsgrad_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf);

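// Overload that takes the learning rate as a Tensor (e.g. a 0-dim CUDA tensor),
// so the fused step can read lr on-device rather than requiring a host-side
// scalar value.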
void _fused_adam_amsgrad_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf);

} // namespace native
} // namespace at