// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/rmsprop.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/optim/rmsprop.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <functional>

namespace torch {
namespace optim {

RMSpropOptions::RMSpropOptions(double lr) : lr_(lr) {}

bool operator==(const RMSpropOptions& lhs, const RMSpropOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.alpha() == rhs.alpha()) &&
      (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.momentum() == rhs.momentum()) && (lhs.centered() == rhs.centered());
}

void RMSpropOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(alpha);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(centered);
}

void RMSpropOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, alpha);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, momentum);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, centered);
}

double RMSpropOptions::get_lr() const {
  return lr();
}

void RMSpropOptions::set_lr(const double lr) {
  this->lr(lr);
}
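
// A minimal usage sketch (not part of this translation unit): options are
// configured through the fluent setters declared in rmsprop.h, and `model`
// is assumed to be some torch::nn::Module holding parameters:
//
//   auto options = torch::optim::RMSpropOptions(/*lr=*/0.01)
//                      .alpha(0.99)
//                      .eps(1e-8)
//                      .momentum(0.9)
//                      .centered(true);
//   torch::optim::RMSprop optimizer(model->parameters(), options);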

bool operator==(const RMSpropParamState& lhs, const RMSpropParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.square_avg(), rhs.square_avg()) &&
      torch::equal_if_defined(lhs.momentum_buffer(), rhs.momentum_buffer()) &&
      torch::equal_if_defined(lhs.grad_avg(), rhs.grad_avg());
}

void RMSpropParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(square_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum_buffer);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(grad_avg);
}

void RMSpropParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, square_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, momentum_buffer);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, grad_avg);
}

/// Adapted from
/// https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
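//
// The update implemented below, for each parameter `p` with gradient `g`:
//   square_avg = alpha * square_avg + (1 - alpha) * g^2
//   centered:      avg = sqrt(square_avg - grad_avg^2) + eps,
//                  where grad_avg = alpha * grad_avg + (1 - alpha) * g
//   not centered:  avg = sqrt(square_avg) + eps
//   momentum > 0:  buf = momentum * buf + g / avg;  p -= lr * buf
//   otherwise:     p -= lr * g / avg
// Nonzero weight_decay is folded in first as g = g + weight_decay * p.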
Tensor RMSprop::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(
          !grad.is_sparse(), "RMSprop does not support sparse gradients");
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<RMSpropOptions&>(group.options());

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<RMSpropParamState>();
        state->step(0);
        state->square_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.momentum() > 0) {
          state->momentum_buffer(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        if (options.centered()) {
          state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

      auto& state =
          static_cast<RMSpropParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& square_avg = state.square_avg();
      auto alpha = options.alpha();

      state.step(state.step() + 1);

      if (options.weight_decay() != 0) {
        grad = grad.add(p, options.weight_decay());
      }

      square_avg.mul_(alpha).addcmul_(grad, grad, 1 - alpha);

      Tensor avg;
      if (options.centered()) {
        auto& grad_avg = state.grad_avg();
        grad_avg.mul_(alpha).add_(grad, 1 - alpha);
        avg = square_avg.addcmul(grad_avg, grad_avg, -1)
                  .sqrt_()
                  .add_(options.eps());
      } else {
        avg = square_avg.sqrt().add_(options.eps());
      }

      if (options.momentum() > 0) {
        auto& buf = state.momentum_buffer();
        buf.mul_(options.momentum()).addcdiv_(grad, avg);
        // Need to avoid version tracking for parameter.
        p.add_(buf, -options.lr());
      } else {
        // Need to avoid version tracking for parameter.
        p.addcdiv_(grad, avg, -options.lr());
      }
    }
  }
  return loss;
}
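
// A training-step sketch using the closure overload above (illustrative
// only; `model`, `criterion`, `input`, and `target` are assumed to exist):
//
//   auto loss = optimizer.step([&]() -> torch::Tensor {
//     optimizer.zero_grad();
//     auto output = model->forward(input);
//     auto l = criterion(output, target);
//     l.backward();
//     return l;
//   });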

void RMSprop::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}
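
// Round-trip sketch: save() and load() are the hooks used when serializing
// the optimizer through the generic torch::save / torch::load entry points,
// e.g. (the path is illustrative):
//
//   torch::save(optimizer, "rmsprop_state.pt");
//   torch::load(optimizer, "rmsprop_state.pt");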

void RMSprop::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized RMSprop optimizer is still using the old serialization format. "
        "The step value in state will be set to 0 because the old RMSprop optimizer didn't track the step value. "
        "You should re-save your RMSprop optimizer to use the new serialization format.");
    std::vector<Tensor> square_average_buffers;
    std::vector<Tensor> momentum_buffers;
    std::vector<Tensor> grad_average_buffers;
    torch::optim::serialize(
        archive, "square_average_buffers", square_average_buffers);
    torch::optim::serialize(archive, "momentum_buffers", momentum_buffers);
    torch::optim::serialize(
        archive, "grad_average_buffers", grad_average_buffers);
    // Since there were no param_groups prior to version 1.5.0, assume all
    // tensors now live in one param_group.
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(square_average_buffers.size())) {
      auto state = std::make_unique<RMSpropParamState>();
      state->square_avg(square_average_buffers[idx]);
      if (idx < momentum_buffers.size()) {
        state->momentum_buffer(momentum_buffers.at(idx));
      }
      if (idx < grad_average_buffers.size()) {
        state->grad_avg(grad_average_buffers.at(idx));
      }
      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch