// external/pytorch/torch/csrc/jit/mobile/train/optim/sgd.cpp
#include <torch/csrc/jit/mobile/train/optim/sgd.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

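// SGD optimizer for training with the mobile (lite) interpreter. The
// update rule mirrors torch::optim::SGD: optional weight decay, momentum
// with dampening, and optional Nesterov momentum.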
namespace torch::jit::mobile {

bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_;
}

const SGDOptions& SGDParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_;
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& SGDParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& SGDParamGroup::params() const {
  return params_;
}

SGDOptions::SGDOptions(double lr) : lr_(lr) {}

bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
      (lhs.dampening() == rhs.dampening()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.nesterov() == rhs.nesterov());
}

bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
}

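// Adds a group of parameters to optimize. Every parameter must be a leaf
// tensor and may appear in only one group; a group without its own options
// receives a clone of the optimizer defaults.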
void SGD::add_param_group(const SGDParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  SGDParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(p.unsafeGetTensorImpl()) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}

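// Resets all gradients: each defined gradient is detached from its graph
// and zeroed in place.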
void SGD::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
}

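// Performs a single optimization step. With learning rate lr, weight
// decay wd, momentum mu, and dampening tau, the per-parameter update is:
//   g   <- grad + wd * p                      (weight decay)
//   buf <- mu * buf + (1 - tau) * g           (after the first step)
//   g   <- g + mu * buf  (Nesterov)  or  g <- buf
//   p   <- p - lr * g
// If a closure is given, it runs with gradients enabled and its result is
// returned as the loss.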
Tensor SGD::step(const LossClosure& closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
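      // Fold L2 weight decay into the gradient: g <- g + wd * p.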
      if (weight_decay != 0) {
        d_p = d_p.add(p.data(), weight_decay);
      }
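      // Momentum: keep a per-parameter buffer keyed by TensorImpl. On the
      // first step the buffer is seeded with the (possibly decayed)
      // gradient as-is; dampening only applies to later accumulations.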
      if (momentum != 0) {
        Tensor buf;
        auto param_state = state_.find(p.unsafeGetTensorImpl());
        if (param_state == state_.end()) {
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[p.unsafeGetTensorImpl()] = std::move(state);
        } else {
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
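        // Nesterov momentum looks one step ahead (g <- g + mu * buf);
        // classic momentum uses the buffer directly.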
        if (nesterov) {
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
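      // Apply the step: p <- p - lr * g.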
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
} // namespace torch::jit::mobile
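
// Minimal usage sketch (assuming the SGD(std::vector<Tensor>, SGDOptions)
// constructor and the defaulted step() closure argument declared in sgd.h):
//
//   std::vector<at::Tensor> params = /* leaf tensors being trained */;
//   torch::jit::mobile::SGD opt(params, SGDOptions(/*lr=*/0.01).momentum(0.9));
//   // Each iteration:
//   opt.zero_grad();
//   loss.backward();
//   opt.step();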