#include <torch/optim/rmsprop.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <functional>

namespace torch {
namespace optim {

RMSpropOptions::RMSpropOptions(double lr) : lr_(lr) {}

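// Example usage (a minimal sketch; `model` is assumed to be a
// torch::nn::Module defined by the caller):
//
//   torch::optim::RMSprop optimizer(
//       model.parameters(),
//       torch::optim::RMSpropOptions(/*lr=*/0.01).alpha(0.99).eps(1e-8));
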
bool operator==(const RMSpropOptions& lhs, const RMSpropOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.alpha() == rhs.alpha()) &&
      (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.momentum() == rhs.momentum()) && (lhs.centered() == rhs.centered());
}

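// The _TORCH_OPTIM_(DE)SERIALIZE_TORCH_ARG macros (from
// torch/optim/serialize.h) read/write each field using its own name as the
// archive key.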
void RMSpropOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(alpha);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(centered);
}

void RMSpropOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, alpha);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, momentum);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, centered);
}

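// get_lr()/set_lr() implement the generic OptimizerOptions interface, so
// code that works across optimizers (e.g. learning-rate schedulers) can
// adjust the learning rate without knowing the concrete options type.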
double RMSpropOptions::get_lr() const {
  return lr();
}

void RMSpropOptions::set_lr(const double lr) {
  this->lr(lr);
}

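// torch::equal_if_defined treats two undefined tensors as equal, which
// matters here because momentum_buffer and grad_avg are only allocated when
// momentum > 0 or centered == true, respectively.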
bool operator==(const RMSpropParamState& lhs, const RMSpropParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.square_avg(), rhs.square_avg()) &&
      torch::equal_if_defined(lhs.momentum_buffer(), rhs.momentum_buffer()) &&
      torch::equal_if_defined(lhs.grad_avg(), rhs.grad_avg());
}

void RMSpropParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(square_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum_buffer);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(grad_avg);
}

void RMSpropParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, square_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, momentum_buffer);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, grad_avg);
}

/// Adapted from
/// https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
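///
/// Per-parameter update (with smoothing constant `alpha`, step size `lr`):
///   square_avg = alpha * square_avg + (1 - alpha) * grad^2
///   avg        = sqrt(square_avg) + eps            (uncentered)
///   p         -= lr * grad / avg                   (no momentum)
///
/// Example invocation with a loss closure (a sketch; `model`, `input`,
/// `target`, and `optimizer` are assumed to be defined by the caller):
///
///   auto loss = optimizer.step([&] {
///     optimizer.zero_grad();
///     auto out = torch::mse_loss(model->forward(input), target);
///     out.backward();
///     return out;
///   });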
Tensor RMSprop::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(
          !grad.is_sparse(), "RMSprop does not support sparse gradients");
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<RMSpropOptions&>(group.options());

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<RMSpropParamState>();
        state->step(0);
        state->square_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.momentum() > 0) {
          state->momentum_buffer(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        if (options.centered()) {
          state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

      auto& state =
          static_cast<RMSpropParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& square_avg = state.square_avg();
      auto alpha = options.alpha();

      state.step(state.step() + 1);

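      // Weight decay is folded into the gradient as L2 regularization:
      // grad += weight_decay * p.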
      if (options.weight_decay() != 0) {
        grad = grad.add(p, options.weight_decay());
      }

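      // Exponential moving average of squared gradients:
      // square_avg = alpha * square_avg + (1 - alpha) * grad^2.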
      square_avg.mul_(alpha).addcmul_(grad, grad, 1 - alpha);

      Tensor avg;
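      // The centered variant also tracks the mean gradient and normalizes by
      // an estimate of the gradient's variance, sqrt(E[g^2] - E[g]^2) + eps,
      // instead of the raw second moment.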
      if (options.centered()) {
        auto& grad_avg = state.grad_avg();
        grad_avg.mul_(alpha).add_(grad, 1 - alpha);
        avg = square_avg.addcmul(grad_avg, grad_avg, -1)
                  .sqrt_()
                  .add_(options.eps());
      } else {
        avg = square_avg.sqrt().add_(options.eps());
      }

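      // With momentum, the normalized gradient accumulates in a velocity
      // buffer (buf = momentum * buf + grad / avg) and the parameter moves
      // along the buffer; otherwise the parameter is updated directly.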
      if (options.momentum() > 0) {
        auto& buf = state.momentum_buffer();
        buf.mul_(options.momentum()).addcdiv_(grad, avg);
        // Need to avoid version tracking for parameter.
        p.add_(buf, -options.lr());
      } else {
        // Need to avoid version tracking for parameter.
        p.addcdiv_(grad, avg, -options.lr());
      }
    }
  }
  return loss;
}

void RMSprop::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void RMSprop::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized RMSprop optimizer is still using the old serialization format. "
        "The step value in state will be set to 0 because the old RMSprop optimizer didn't track the step value. "
        "You should re-save your RMSprop optimizer to use the new serialization format.");
    std::vector<Tensor> square_average_buffers;
    std::vector<Tensor> momentum_buffers;
    std::vector<Tensor> grad_average_buffers;
    torch::optim::serialize(
        archive, "square_average_buffers", square_average_buffers);
    torch::optim::serialize(archive, "momentum_buffers", momentum_buffers);
    torch::optim::serialize(
        archive, "grad_average_buffers", grad_average_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(square_average_buffers.size())) {
      auto state = std::make_unique<RMSpropParamState>();
      state->square_avg(square_average_buffers[idx]);
      if (idx < momentum_buffers.size()) {
        state->momentum_buffer(momentum_buffers.at(idx));
      }
      if (idx < grad_average_buffers.size()) {
        state->grad_avg(grad_average_buffers.at(idx));
      }
      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch