#include <torch/csrc/jit/mobile/train/optim/sgd.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

namespace torch::jit::mobile {

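// A SGDParamGroup bundles a set of parameters with an optional, group-specific
// SGDOptions instance; groups without their own options fall back to the
// optimizer-wide defaults when added via SGD::add_param_group().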
bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_;
}

const SGDOptions& SGDParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_;
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& SGDParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& SGDParamGroup::params() const {
  return params_;
}

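// SGDOptions carries the hyperparameters of the update rule: learning rate,
// momentum, dampening, weight decay, and the Nesterov flag. Two option sets
// compare equal only if every field matches; two parameter states compare
// equal only if their momentum buffers are element-wise equal.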
SGDOptions::SGDOptions(double lr) : lr_(lr) {}

bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
      (lhs.dampening() == rhs.dampening()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.nesterov() == rhs.nesterov());
}

bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
}

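// Registers a new parameter group. Every parameter must be a leaf tensor and
// may appear in at most one group; a group without its own options receives a
// clone of the optimizer-wide defaults.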
void SGD::add_param_group(const SGDParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  SGDParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(p.unsafeGetTensorImpl()) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}

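// Clears the gradients of every registered parameter by detaching each
// defined grad tensor from its autograd history and zeroing it in place.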
void SGD::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
}

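// Performs a single optimization step. If a loss closure is supplied, it is
// re-evaluated with gradients enabled and its result is returned; the
// parameter updates themselves run under NoGradGuard so they are not recorded
// by autograd.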
Tensor SGD::step(const LossClosure& closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

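    // Per-parameter update, following the classic SGD-with-momentum rule:
    //   d_p = grad + weight_decay * p
    //   buf = momentum * buf + (1 - dampening) * d_p
    //         (on the first step the buffer is initialized to d_p)
    //   d_p = d_p + momentum * buf   if nesterov, else   d_p = buf
    //   p   = p - lr * d_p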
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
      if (weight_decay != 0) {
        d_p = d_p.add(p.data(), weight_decay);
      }
      if (momentum != 0) {
        Tensor buf;
        auto param_state = state_.find(p.unsafeGetTensorImpl());
        if (param_state == state_.end()) {
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[p.unsafeGetTensorImpl()] = std::move(state);
        } else {
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
        if (nesterov) {
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
} // namespace torch::jit::mobile