xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/rmsprop.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/module.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/optim/serialize.h>
6 #include <torch/serialize/archive.h>
7 #include <torch/types.h>
8 
9 #include <functional>
10 #include <memory>
11 #include <string>
12 #include <vector>
13 
// Forward declarations of the archive types named by the `serialize`,
// `save` and `load` declarations in this header. The full definitions are
// presumably also reachable through the <torch/serialize/archive.h>
// include; declaring them here keeps these signatures self-contained.
namespace torch::serialize {
class InputArchive;
class OutputArchive;
} // namespace torch::serialize
20 
21 namespace torch {
22 namespace optim {
23 
24 struct TORCH_API RMSpropOptions
25     : public OptimizerCloneableOptions<RMSpropOptions> {
26   RMSpropOptions(double lr = 1e-2);
27   TORCH_ARG(double, lr) = 1e-2;
28   TORCH_ARG(double, alpha) = 0.99;
29   TORCH_ARG(double, eps) = 1e-8;
30   TORCH_ARG(double, weight_decay) = 0;
31   TORCH_ARG(double, momentum) = 0;
32   TORCH_ARG(bool, centered) = false;
33 
34  public:
35   void serialize(torch::serialize::InputArchive& archive) override;
36   void serialize(torch::serialize::OutputArchive& archive) const override;
37   TORCH_API friend bool operator==(
38       const RMSpropOptions& lhs,
39       const RMSpropOptions& rhs);
40   double get_lr() const override;
41   void set_lr(const double lr) override;
42 };
43 
44 struct TORCH_API RMSpropParamState
45     : public OptimizerCloneableParamState<RMSpropParamState> {
46   TORCH_ARG(int64_t, step) = 0;
47   TORCH_ARG(torch::Tensor, square_avg);
48   TORCH_ARG(torch::Tensor, momentum_buffer) = {};
49   TORCH_ARG(torch::Tensor, grad_avg) = {};
50 
51  public:
52   void serialize(torch::serialize::InputArchive& archive) override;
53   void serialize(torch::serialize::OutputArchive& archive) const override;
54   TORCH_API friend bool operator==(
55       const RMSpropParamState& lhs,
56       const RMSpropParamState& rhs);
57 };
58 
59 class TORCH_API RMSprop : public Optimizer {
60  public:
61   explicit RMSprop(
62       std::vector<OptimizerParamGroup> param_groups,
63       RMSpropOptions defaults = {})
Optimizer(std::move (param_groups),std::make_unique<RMSpropOptions> (defaults))64       : Optimizer(
65             std::move(param_groups),
66             std::make_unique<RMSpropOptions>(defaults)) {
67     TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
68     TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
69     TORCH_CHECK(
70         defaults.momentum() >= 0,
71         "Invalid momentum value: ",
72         defaults.momentum());
73     TORCH_CHECK(
74         defaults.weight_decay() >= 0,
75         "Invalid weight_decay value: ",
76         defaults.weight_decay());
77     TORCH_CHECK(
78         defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha());
79   }
80 
81   explicit RMSprop(std::vector<Tensor> params, RMSpropOptions defaults = {})
82       : RMSprop({OptimizerParamGroup(std::move(params))}, defaults) {}
83 
84   torch::Tensor step(LossClosure closure = nullptr) override;
85   void save(serialize::OutputArchive& archive) const override;
86   void load(serialize::InputArchive& archive) override;
87 
88  private:
89   template <typename Self, typename Archive>
serialize(Self & self,Archive & archive)90   static void serialize(Self& self, Archive& archive) {
91     _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RMSprop);
92   }
93 };
94 } // namespace optim
95 } // namespace torch
96