1 #pragma once 2 #include <ATen/core/Tensor.h> 3 4 namespace at { 5 namespace native { 6 7 void _fused_adam_amsgrad_cuda_impl_( 8 at::TensorList params, 9 at::TensorList grads, 10 at::TensorList exp_avgs, 11 at::TensorList exp_avg_sqs, 12 at::TensorList max_exp_avg_sqs, 13 at::TensorList state_steps, 14 const double lr, 15 const double beta1, 16 const double beta2, 17 const double weight_decay, 18 const double eps, 19 const bool maximize, 20 const std::optional<at::Tensor>& grad_scale, 21 const std::optional<at::Tensor>& found_inf); 22 23 void _fused_adam_amsgrad_cuda_impl_( 24 at::TensorList params, 25 at::TensorList grads, 26 at::TensorList exp_avgs, 27 at::TensorList exp_avg_sqs, 28 at::TensorList max_exp_avg_sqs, 29 at::TensorList state_steps, 30 const at::Tensor& lr, 31 const double beta1, 32 const double beta2, 33 const double weight_decay, 34 const double eps, 35 const bool maximize, 36 const std::optional<at::Tensor>& grad_scale, 37 const std::optional<at::Tensor>& found_inf); 38 39 } // namespace native 40 } // namespace at 41