/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/training/optimizer/sgd.h>

#include <executorch/runtime/core/error.h>

using exec_aten::Tensor;
using exec_aten::TensorImpl;
using ::executorch::runtime::Error;

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

namespace {
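// Writes out[i] = a[i] + alpha * b[i] directly over float data. Assumes all
// three tensors share the same number of elements and a float dtype.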
void add_out_hack(
    const Tensor& a,
    const Tensor& b,
    const double alpha,
    Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto b_ptr = b.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;
  }
}

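// Writes out[i] = alpha * a[i] over float data.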
void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i] * alpha;
  }
}

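// Copies a's float data element by element into out.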
void clone_out_hack(const Tensor& a, Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i];
  }
}
} // namespace

bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  return *options_.get();
}

const SGDOptions& SGDParamGroup::options() const {
  return *options_.get();
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

const std::map<exec_aten::string_view, exec_aten::Tensor>&
SGDParamGroup::named_parameters() const {
  return named_parameters_;
}

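// Registers a parameter group with the optimizer. Groups that do not carry
// their own options get a clone of the optimizer-wide defaults.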
void SGD::add_param_group(const SGDParamGroup& param_group) {
  SGDParamGroup param_group_(param_group.named_parameters());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  param_groups_.emplace_back(std::move(param_group_));
}

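// Performs one SGD update over every parameter group. For each parameter whose
// name has a matching gradient, the step is:
//   d_p = grad + weight_decay * p
//   buf = momentum * buf + (1 - dampening) * d_p   (momentum != 0; buf starts
//                                                    as a copy of d_p)
//   d_p = d_p + momentum * buf  if nesterov, else  d_p = buf
//   p   = p - lr * d_p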
Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
                    named_gradients) {
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto param_iter = group.named_parameters().begin();
         param_iter != group.named_parameters().end();
         ++param_iter) {
      // if param name and gradient name match, run the optimizer step
      const auto& named_gradient = named_gradients.find(param_iter->first);
      if (named_gradient != named_gradients.end()) {
        auto d_p = named_gradient->second;
        auto p = param_iter->second;
        if (weight_decay != 0) {
          // uses weight_decay specified and adds it to the gradient
          add_out_hack(d_p, p, weight_decay, d_p);
        }
        if (momentum != 0) {
          Tensor buf(nullptr);
          auto param_state = state_.find(p.unsafeGetTensorImpl());
          // look for the momentum buffer for the given parameter. this is the
          // momentum as of the previous epoch
          if (param_state == state_.end()) {
            // create a new momentum buffer if it doesn't exist. this memory
            // needs to be freed when the optimizer is destroyed
            void* buf_ptr = malloc(d_p.nbytes());

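            // Wrap the raw allocation in a Tensor. With ATen, from_blob views
            // the buffer without taking ownership; in the portable build a
            // TensorImpl is constructed by hand and released in the destructor.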
#ifdef USE_ATEN_LIB
            std::vector<int64_t> sizes(d_p.sizes().begin(), d_p.sizes().end());
            buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
#else
            TensorImpl* buf_impl = new TensorImpl(
                d_p.scalar_type(),
                d_p.sizes().size(),
                const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
                buf_ptr,
                const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
            buf = Tensor(buf_impl);
#endif
            clone_out_hack(d_p, buf);

            // save the state of the momentum buffer to be reused in later
            // epochs
            auto state = std::make_unique<SGDParamState>(buf);
            state_[p.unsafeGetTensorImpl()] = std::move(state);
          } else {
            buf = static_cast<SGDParamState&>(*param_state->second)
                      .momentum_buffer();

            // update the momentum buffer and apply dampening
            mul_out_hack(buf, momentum, buf);
            add_out_hack(buf, d_p, 1 - dampening, buf);
          }
          if (nesterov) {
            // apply nesterov momentum
            add_out_hack(d_p, buf, momentum, d_p);
          } else {
            d_p = buf;
          }
        }
        // update the parameter using the gradient and learning rate
        add_out_hack(p, d_p, -1 * options.lr(), p);
      }
    }
  }
  return Error::Ok;
}

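// Releases the momentum buffers allocated in step(): the malloc'd data in both
// builds, plus the hand-built TensorImpls in the portable (non-ATen) build.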
SGD::~SGD() {
  for (const auto& state_kv : state_) {
    auto state_tensor = static_cast<SGDParamState&>(*state_kv.second);
    free(state_tensor.momentum_buffer().unsafeGetTensorImpl()->mutable_data());
#ifndef USE_ATEN_LIB
    delete state_tensor.momentum_buffer().unsafeGetTensorImpl();
#endif
  }
}

} // namespace optimizer
} // namespace training
} // namespace extension
} // namespace executorch