/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/training/optimizer/sgd.h>

#include <executorch/runtime/core/error.h>

using exec_aten::Tensor;
using exec_aten::TensorImpl;
using ::executorch::runtime::Error;

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

namespace {
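// Minimal element-wise stand-ins for real tensor kernels. They only handle
// float data and assume all operands share the same layout; `out` may alias
// an input, which the call sites below rely on for in-place updates.
// add_out_hack computes out = a + b * alpha.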
void add_out_hack(
    const Tensor& a,
    const Tensor& b,
    const double alpha,
    Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto b_ptr = b.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;
  }
}

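// out = a * alpha, element-wise.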
void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i] * alpha;
  }
}

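// Copies a into out, element-wise.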
void clone_out_hack(const Tensor& a, Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i];
  }
}
} // namespace

bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  return *options_.get();
}

const SGDOptions& SGDParamGroup::options() const {
  return *options_.get();
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

const std::map<exec_aten::string_view, exec_aten::Tensor>&
SGDParamGroup::named_parameters() const {
  return named_parameters_;
}

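// Adds a parameter group to the optimizer. Groups without their own options
// get a clone of the optimizer-wide defaults.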
void SGD::add_param_group(const SGDParamGroup& param_group) {
  SGDParamGroup param_group_(param_group.named_parameters());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  param_groups_.emplace_back(std::move(param_group_));
}

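// Performs one optimization step. For each parameter p whose name matches an
// entry g in named_gradients:
//   g = g + weight_decay * p                    (if weight_decay != 0)
//   buf = momentum * buf + (1 - dampening) * g  (if momentum != 0; on the
//                                                first step, buf = g)
//   g = g + momentum * buf  if nesterov, else  g = buf
//   p = p - lr * g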
Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
                    named_gradients) {
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto param_iter = group.named_parameters().begin();
         param_iter != group.named_parameters().end();
         ++param_iter) {
      // If a gradient with the same name as the parameter exists, run the
      // optimizer step for that parameter.
      const auto& named_gradient = named_gradients.find(param_iter->first);
      if (named_gradient != named_gradients.end()) {
        auto d_p = named_gradient->second;
        auto p = param_iter->second;
        if (weight_decay != 0) {
          // Apply weight decay: d_p += weight_decay * p.
          add_out_hack(d_p, p, weight_decay, d_p);
        }
        if (momentum != 0) {
          Tensor buf(nullptr);
          auto param_state = state_.find(p.unsafeGetTensorImpl());
          // Look up the momentum buffer for this parameter; it holds the
          // momentum carried over from the previous step.
          if (param_state == state_.end()) {
            // Create a new momentum buffer if it doesn't exist. This memory
            // is freed in the SGD destructor.
            void* buf_ptr = malloc(d_p.nbytes());

#ifdef USE_ATEN_LIB
            std::vector<int64_t> sizes(d_p.sizes().begin(), d_p.sizes().end());
            buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
#else
            TensorImpl* buf_impl = new TensorImpl(
                d_p.scalar_type(),
                d_p.sizes().size(),
                const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
                buf_ptr,
                const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
            buf = Tensor(buf_impl);
#endif
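            // Note (lean mode): buf_impl above borrows d_p's sizes and
            // dim_order arrays rather than copying them, so the gradient
            // tensor's metadata is assumed to outlive this buffer.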
            clone_out_hack(d_p, buf);

            // Save the momentum buffer so it can be reused in later steps.
            auto state = std::make_unique<SGDParamState>(buf);
            state_[p.unsafeGetTensorImpl()] = std::move(state);
          } else {
            buf = static_cast<SGDParamState&>(*param_state->second)
                      .momentum_buffer();

            // Update the momentum buffer, applying dampening to the new
            // gradient contribution.
            mul_out_hack(buf, momentum, buf);
            add_out_hack(buf, d_p, 1 - dampening, buf);
          }
          if (nesterov) {
            // Apply Nesterov momentum: d_p += momentum * buf.
            add_out_hack(d_p, buf, momentum, d_p);
          } else {
            d_p = buf;
          }
        }
        // Update the parameter with the (possibly momentum-adjusted) gradient
        // scaled by the learning rate: p -= lr * d_p.
        add_out_hack(p, d_p, -1 * options.lr(), p);
      }
    }
  }
  return Error::Ok;
}

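// Frees the momentum buffers allocated with malloc() in step() and, in lean
// (non-ATen) mode, the TensorImpls that wrap them.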
SGD::~SGD() {
  for (const auto& state_kv : state_) {
    auto& state_tensor = static_cast<SGDParamState&>(*state_kv.second);
    free(state_tensor.momentum_buffer().unsafeGetTensorImpl()->mutable_data());
#ifndef USE_ATEN_LIB
    delete state_tensor.momentum_buffer().unsafeGetTensorImpl();
#endif
  }
}
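
// Usage sketch (constructor and option signatures assumed from sgd.h, not
// defined in this file):
//
//   std::map<exec_aten::string_view, exec_aten::Tensor> params = ...;
//   std::map<exec_aten::string_view, exec_aten::Tensor> grads = ...;
//   SGD optimizer(params, SGDOptions(/*lr=*/0.1, /*momentum=*/0.9));
//   Error err = optimizer.step(grads);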

} // namespace optimizer
} // namespace training
} // namespace extension
} // namespace executorch