1 #pragma once 2 #include <ATen/core/Tensor.h> 3 4 namespace at { 5 namespace native { 6 7 void _fused_adamw_cuda_impl_( 8 at::TensorList params, 9 at::TensorList grads, 10 at::TensorList exp_avgs, 11 at::TensorList exp_avg_sqs, 12 at::TensorList state_steps, 13 const double lr, 14 const double beta1, 15 const double beta2, 16 const double weight_decay, 17 const double eps, 18 const bool maximize, 19 const std::optional<at::Tensor>& grad_scale, 20 const std::optional<at::Tensor>& found_inf); 21 22 void _fused_adamw_cuda_impl_( 23 at::TensorList params, 24 at::TensorList grads, 25 at::TensorList exp_avgs, 26 at::TensorList exp_avg_sqs, 27 at::TensorList state_steps, 28 const at::Tensor& lr, 29 const double beta1, 30 const double beta2, 31 const double weight_decay, 32 const double eps, 33 const bool maximize, 34 const std::optional<at::Tensor>& grad_scale, 35 const std::optional<at::Tensor>& found_inf); 36 37 } // namespace native 38 } // namespace at 39