#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

// Signature of the device-specific fused Adagrad kernel. The kernel updates
// `param`, `state_sum` (accumulated squared gradients), and `state_step`
// (step counter) in place, applying the Adagrad hyperparameters `lr`,
// `lr_decay`, `weight_decay`, and `eps`. When `maximize` is true the
// objective is maximized rather than minimized. `grad_scale_ptr` optionally
// points to a gradient scale factor (as used with AMP gradient scaling) and
// may be null.
using fused_adagrad_fn = void (*)(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& state_sum,
    const at::Tensor& state_step,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const float* grad_scale_ptr);

// Dispatch stub that routes to the per-device fused Adagrad implementation.
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);

} // namespace at::native
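// Illustrative sketch (not part of this header): a backend supplies a kernel
// matching fused_adagrad_fn and registers it against the stub; the kernel
// name below is hypothetical, and the real CPU kernel lives in the
// corresponding cpu/ kernel translation unit.
//
//   void fused_adagrad_kernel(
//       const at::Tensor& param, const at::Tensor& grad,
//       const at::Tensor& state_sum, const at::Tensor& state_step,
//       const double lr, const double lr_decay, const double weight_decay,
//       const double eps, const bool maximize, const float* grad_scale_ptr) {
//     // device-specific in-place update of param / state_sum / state_step
//   }
//   REGISTER_DISPATCH(fused_adagrad_stub, &fused_adagrad_kernel);
//
// Callers then dispatch on the tensors' device type:
//
//   fused_adagrad_stub(
//       param.device().type(), param, grad, state_sum, state_step,
//       lr, lr_decay, weight_decay, eps, maximize, grad_scale_ptr);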