1 #include <ATen/core/Tensor.h> 2 #include <ATen/native/DispatchStub.h> 3 4 namespace at::native { 5 6 using fused_adagrad_fn = void (*)( 7 const at::Tensor& param, 8 const at::Tensor& grad, 9 const at::Tensor& state_sum, 10 const at::Tensor& state_step, 11 const double lr, 12 const double lr_decay, 13 const double weight_decay, 14 const double eps, 15 const bool maximize, 16 const float* grad_scale_ptr); 17 18 DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub); 19 20 } // namespace at::native 21