1 #pragma once 2 3 #include <torch/nn/module.h> 4 #include <torch/optim/optimizer.h> 5 #include <torch/optim/serialize.h> 6 #include <torch/serialize/archive.h> 7 #include <torch/types.h> 8 9 #include <functional> 10 #include <memory> 11 #include <string> 12 #include <vector> 13 14 namespace torch { 15 namespace serialize { 16 class OutputArchive; 17 class InputArchive; 18 } // namespace serialize 19 } // namespace torch 20 21 namespace torch { 22 namespace optim { 23 24 struct TORCH_API RMSpropOptions 25 : public OptimizerCloneableOptions<RMSpropOptions> { 26 RMSpropOptions(double lr = 1e-2); 27 TORCH_ARG(double, lr) = 1e-2; 28 TORCH_ARG(double, alpha) = 0.99; 29 TORCH_ARG(double, eps) = 1e-8; 30 TORCH_ARG(double, weight_decay) = 0; 31 TORCH_ARG(double, momentum) = 0; 32 TORCH_ARG(bool, centered) = false; 33 34 public: 35 void serialize(torch::serialize::InputArchive& archive) override; 36 void serialize(torch::serialize::OutputArchive& archive) const override; 37 TORCH_API friend bool operator==( 38 const RMSpropOptions& lhs, 39 const RMSpropOptions& rhs); 40 double get_lr() const override; 41 void set_lr(const double lr) override; 42 }; 43 44 struct TORCH_API RMSpropParamState 45 : public OptimizerCloneableParamState<RMSpropParamState> { 46 TORCH_ARG(int64_t, step) = 0; 47 TORCH_ARG(torch::Tensor, square_avg); 48 TORCH_ARG(torch::Tensor, momentum_buffer) = {}; 49 TORCH_ARG(torch::Tensor, grad_avg) = {}; 50 51 public: 52 void serialize(torch::serialize::InputArchive& archive) override; 53 void serialize(torch::serialize::OutputArchive& archive) const override; 54 TORCH_API friend bool operator==( 55 const RMSpropParamState& lhs, 56 const RMSpropParamState& rhs); 57 }; 58 59 class TORCH_API RMSprop : public Optimizer { 60 public: 61 explicit RMSprop( 62 std::vector<OptimizerParamGroup> param_groups, 63 RMSpropOptions defaults = {}) Optimizer(std::move (param_groups),std::make_unique<RMSpropOptions> (defaults))64 : Optimizer( 65 std::move(param_groups), 66 std::make_unique<RMSpropOptions>(defaults)) { 67 TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); 68 TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); 69 TORCH_CHECK( 70 defaults.momentum() >= 0, 71 "Invalid momentum value: ", 72 defaults.momentum()); 73 TORCH_CHECK( 74 defaults.weight_decay() >= 0, 75 "Invalid weight_decay value: ", 76 defaults.weight_decay()); 77 TORCH_CHECK( 78 defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha()); 79 } 80 81 explicit RMSprop(std::vector<Tensor> params, RMSpropOptions defaults = {}) 82 : RMSprop({OptimizerParamGroup(std::move(params))}, defaults) {} 83 84 torch::Tensor step(LossClosure closure = nullptr) override; 85 void save(serialize::OutputArchive& archive) const override; 86 void load(serialize::InputArchive& archive) override; 87 88 private: 89 template <typename Self, typename Archive> serialize(Self & self,Archive & archive)90 static void serialize(Self& self, Archive& archive) { 91 _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RMSprop); 92 } 93 }; 94 } // namespace optim 95 } // namespace torch 96