xref: /aosp_15_r20/external/pytorch/test/cpp/api/optim.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/optim_baseline.h>
7*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker #include <cmath>
10*da0073e9SAndroid Build Coastguard Worker #include <cstdlib>
11*da0073e9SAndroid Build Coastguard Worker #include <functional>
12*da0073e9SAndroid Build Coastguard Worker #include <iostream>
13*da0073e9SAndroid Build Coastguard Worker #include <memory>
14*da0073e9SAndroid Build Coastguard Worker #include <random>
15*da0073e9SAndroid Build Coastguard Worker #include <vector>
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
18*da0073e9SAndroid Build Coastguard Worker using namespace torch::optim;
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker template <typename OptimizerClass, typename Options>
test_optimizer_xor(Options options)21*da0073e9SAndroid Build Coastguard Worker bool test_optimizer_xor(Options options) {
22*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker   Sequential model(
25*da0073e9SAndroid Build Coastguard Worker       Linear(2, 8),
26*da0073e9SAndroid Build Coastguard Worker       Functional(torch::sigmoid),
27*da0073e9SAndroid Build Coastguard Worker       Linear(8, 1),
28*da0073e9SAndroid Build Coastguard Worker       Functional(torch::sigmoid));
29*da0073e9SAndroid Build Coastguard Worker 
30*da0073e9SAndroid Build Coastguard Worker   const int64_t kBatchSize = 200;
31*da0073e9SAndroid Build Coastguard Worker   const int64_t kMaximumNumberOfEpochs = 3000;
32*da0073e9SAndroid Build Coastguard Worker 
33*da0073e9SAndroid Build Coastguard Worker   OptimizerClass optimizer(model->parameters(), options);
34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Worker   float running_loss = 1;
36*da0073e9SAndroid Build Coastguard Worker   int epoch = 0;
37*da0073e9SAndroid Build Coastguard Worker   while (running_loss > 0.1) {
38*da0073e9SAndroid Build Coastguard Worker     auto inputs = torch::empty({kBatchSize, 2});
39*da0073e9SAndroid Build Coastguard Worker     auto labels = torch::empty({kBatchSize});
40*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(kBatchSize)) {
41*da0073e9SAndroid Build Coastguard Worker       inputs[i] = torch::randint(2, {2}, torch::kInt64);
42*da0073e9SAndroid Build Coastguard Worker       labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
43*da0073e9SAndroid Build Coastguard Worker     }
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker     inputs.set_requires_grad(true);
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker     auto step = [&](OptimizerClass& optimizer,
48*da0073e9SAndroid Build Coastguard Worker                     Sequential model,
49*da0073e9SAndroid Build Coastguard Worker                     torch::Tensor inputs,
50*da0073e9SAndroid Build Coastguard Worker                     torch::Tensor labels) {
51*da0073e9SAndroid Build Coastguard Worker       auto closure = [&]() {
52*da0073e9SAndroid Build Coastguard Worker         optimizer.zero_grad();
53*da0073e9SAndroid Build Coastguard Worker         auto x = model->forward(inputs);
54*da0073e9SAndroid Build Coastguard Worker         auto loss = torch::binary_cross_entropy(x, labels);
55*da0073e9SAndroid Build Coastguard Worker         loss.backward();
56*da0073e9SAndroid Build Coastguard Worker         return loss;
57*da0073e9SAndroid Build Coastguard Worker       };
58*da0073e9SAndroid Build Coastguard Worker       return optimizer.step(closure);
59*da0073e9SAndroid Build Coastguard Worker     };
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker     torch::Tensor loss = step(optimizer, model, inputs, labels);
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
64*da0073e9SAndroid Build Coastguard Worker     running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
65*da0073e9SAndroid Build Coastguard Worker     if (epoch > kMaximumNumberOfEpochs) {
66*da0073e9SAndroid Build Coastguard Worker       std::cout << "Loss is too high after epoch " << epoch << ": "
67*da0073e9SAndroid Build Coastguard Worker                 << running_loss << std::endl;
68*da0073e9SAndroid Build Coastguard Worker       return false;
69*da0073e9SAndroid Build Coastguard Worker     }
70*da0073e9SAndroid Build Coastguard Worker     epoch++;
71*da0073e9SAndroid Build Coastguard Worker   }
72*da0073e9SAndroid Build Coastguard Worker   return true;
73*da0073e9SAndroid Build Coastguard Worker }
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker template <typename Parameters>
assign_parameter(const Parameters & parameters,const char * name,torch::Tensor new_tensor)76*da0073e9SAndroid Build Coastguard Worker void assign_parameter(
77*da0073e9SAndroid Build Coastguard Worker     const Parameters& parameters,
78*da0073e9SAndroid Build Coastguard Worker     const char* name,
79*da0073e9SAndroid Build Coastguard Worker     torch::Tensor new_tensor) {
80*da0073e9SAndroid Build Coastguard Worker   auto parameter = parameters[name];
81*da0073e9SAndroid Build Coastguard Worker   parameter.set_requires_grad(false);
82*da0073e9SAndroid Build Coastguard Worker   parameter.flatten().copy_(new_tensor);
83*da0073e9SAndroid Build Coastguard Worker   parameter.set_requires_grad(true);
84*da0073e9SAndroid Build Coastguard Worker }
85*da0073e9SAndroid Build Coastguard Worker 
86*da0073e9SAndroid Build Coastguard Worker template <typename OptimizerClass, typename Options>
check_exact_values(Options options,std::vector<std::vector<torch::Tensor>> expected_parameters)87*da0073e9SAndroid Build Coastguard Worker void check_exact_values(
88*da0073e9SAndroid Build Coastguard Worker     Options options,
89*da0073e9SAndroid Build Coastguard Worker     std::vector<std::vector<torch::Tensor>> expected_parameters) {
90*da0073e9SAndroid Build Coastguard Worker   const size_t kIterations = 1001;
91*da0073e9SAndroid Build Coastguard Worker   const size_t kSampleEvery = 100;
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
94*da0073e9SAndroid Build Coastguard Worker 
95*da0073e9SAndroid Build Coastguard Worker   Sequential model(
96*da0073e9SAndroid Build Coastguard Worker       Linear(2, 3),
97*da0073e9SAndroid Build Coastguard Worker       Functional(torch::sigmoid),
98*da0073e9SAndroid Build Coastguard Worker       Linear(3, 1),
99*da0073e9SAndroid Build Coastguard Worker       Functional(torch::sigmoid));
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker   model->to(torch::kFloat64);
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker   // Use exact input values because matching random values is hard.
104*da0073e9SAndroid Build Coastguard Worker   auto parameters = model->named_parameters();
105*da0073e9SAndroid Build Coastguard Worker   assign_parameter(
106*da0073e9SAndroid Build Coastguard Worker       parameters,
107*da0073e9SAndroid Build Coastguard Worker       "0.weight",
108*da0073e9SAndroid Build Coastguard Worker       torch::tensor(
109*da0073e9SAndroid Build Coastguard Worker           {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
110*da0073e9SAndroid Build Coastguard Worker           torch::kFloat64));
111*da0073e9SAndroid Build Coastguard Worker   assign_parameter(
112*da0073e9SAndroid Build Coastguard Worker       parameters,
113*da0073e9SAndroid Build Coastguard Worker       "0.bias",
114*da0073e9SAndroid Build Coastguard Worker       torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
115*da0073e9SAndroid Build Coastguard Worker   assign_parameter(
116*da0073e9SAndroid Build Coastguard Worker       parameters,
117*da0073e9SAndroid Build Coastguard Worker       "2.weight",
118*da0073e9SAndroid Build Coastguard Worker       torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
119*da0073e9SAndroid Build Coastguard Worker   assign_parameter(
120*da0073e9SAndroid Build Coastguard Worker       parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));
121*da0073e9SAndroid Build Coastguard Worker 
122*da0073e9SAndroid Build Coastguard Worker   auto optimizer = OptimizerClass(parameters.values(), options);
123*da0073e9SAndroid Build Coastguard Worker   torch::Tensor input =
124*da0073e9SAndroid Build Coastguard Worker       torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
125*da0073e9SAndroid Build Coastguard Worker           .reshape({3, 2});
126*da0073e9SAndroid Build Coastguard Worker 
127*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(kIterations)) {
128*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
129*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
130*da0073e9SAndroid Build Coastguard Worker     auto loss = output.sum();
131*da0073e9SAndroid Build Coastguard Worker     loss.backward();
132*da0073e9SAndroid Build Coastguard Worker 
133*da0073e9SAndroid Build Coastguard Worker     auto closure = []() { return torch::tensor({10}); };
134*da0073e9SAndroid Build Coastguard Worker     optimizer.step(closure);
135*da0073e9SAndroid Build Coastguard Worker 
136*da0073e9SAndroid Build Coastguard Worker     if (i % kSampleEvery == 0) {
137*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(
138*da0073e9SAndroid Build Coastguard Worker           expected_parameters.at(i / kSampleEvery).size() == parameters.size());
139*da0073e9SAndroid Build Coastguard Worker       for (const auto p : c10::irange(parameters.size())) {
140*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(parameters[p]->defined());
141*da0073e9SAndroid Build Coastguard Worker         // Always compare using double dtype, regardless of the original dtype
142*da0073e9SAndroid Build Coastguard Worker         // of the tensors
143*da0073e9SAndroid Build Coastguard Worker         auto computed = parameters[p]->flatten().to(torch::kFloat64);
144*da0073e9SAndroid Build Coastguard Worker         auto expected =
145*da0073e9SAndroid Build Coastguard Worker             expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
146*da0073e9SAndroid Build Coastguard Worker         if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
147*da0073e9SAndroid Build Coastguard Worker           std::cout << "Iteration " << i << ": " << computed
148*da0073e9SAndroid Build Coastguard Worker                     << " != " << expected << " (parameter " << p << ")"
149*da0073e9SAndroid Build Coastguard Worker                     << std::endl;
150*da0073e9SAndroid Build Coastguard Worker           ASSERT_TRUE(false);
151*da0073e9SAndroid Build Coastguard Worker         }
152*da0073e9SAndroid Build Coastguard Worker       }
153*da0073e9SAndroid Build Coastguard Worker     }
154*da0073e9SAndroid Build Coastguard Worker   }
155*da0073e9SAndroid Build Coastguard Worker }
156*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest,OptimizerAccessors)157*da0073e9SAndroid Build Coastguard Worker TEST(OptimTest, OptimizerAccessors) {
158*da0073e9SAndroid Build Coastguard Worker   auto options = AdagradOptions(1.0);
159*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> params;
160*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(3)) {
161*da0073e9SAndroid Build Coastguard Worker     (void)i; // Suppress unused variable warning
162*da0073e9SAndroid Build Coastguard Worker     params.push_back(torch::randn(10));
163*da0073e9SAndroid Build Coastguard Worker   }
164*da0073e9SAndroid Build Coastguard Worker   auto optimizer = Adagrad(params, options);
165*da0073e9SAndroid Build Coastguard Worker   // test for defaults() method with non-const reference
166*da0073e9SAndroid Build Coastguard Worker   auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
167*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(options == options_);
168*da0073e9SAndroid Build Coastguard Worker   // test for param_groups() with non-const reference return
169*da0073e9SAndroid Build Coastguard Worker   auto& params_groups = optimizer.param_groups();
170*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-use-emplace)
171*da0073e9SAndroid Build Coastguard Worker   params_groups.push_back(OptimizerParamGroup(params));
172*da0073e9SAndroid Build Coastguard Worker   auto& params_1 = params_groups[1].params();
173*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params_1.size())) {
174*da0073e9SAndroid Build Coastguard Worker     torch::equal(params[i], params_1[i]);
175*da0073e9SAndroid Build Coastguard Worker   }
176*da0073e9SAndroid Build Coastguard Worker 
177*da0073e9SAndroid Build Coastguard Worker   // test for add_param_group() when one or more params existing in another
178*da0073e9SAndroid Build Coastguard Worker   // param_group are passed in the new param group to be added
179*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
180*da0073e9SAndroid Build Coastguard Worker       optimizer.add_param_group(OptimizerParamGroup(params)),
181*da0073e9SAndroid Build Coastguard Worker       "some parameters appear in more than one parameter group");
182*da0073e9SAndroid Build Coastguard Worker 
183*da0073e9SAndroid Build Coastguard Worker   // test for state() with non-const reference return
184*da0073e9SAndroid Build Coastguard Worker   auto& state_ = static_cast<AdagradParamState&>(
185*da0073e9SAndroid Build Coastguard Worker       *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
186*da0073e9SAndroid Build Coastguard Worker   state_.step(state_.step() + 1);
187*da0073e9SAndroid Build Coastguard Worker 
188*da0073e9SAndroid Build Coastguard Worker   const auto& optimizer_ = Adagrad(params, options);
189*da0073e9SAndroid Build Coastguard Worker   optimizer_.defaults();
190*da0073e9SAndroid Build Coastguard Worker   // test for param_groups() with const reference return
191*da0073e9SAndroid Build Coastguard Worker   (void)optimizer_.param_groups();
192*da0073e9SAndroid Build Coastguard Worker   // test for state() with const reference return
193*da0073e9SAndroid Build Coastguard Worker   optimizer_.state();
194*da0073e9SAndroid Build Coastguard Worker }
195*da0073e9SAndroid Build Coastguard Worker 
// Evaluates `func` (a statement or expression) with a
// torch::test::WarningCapture in scope. It then asserts that the captured text
// (warnings.str()) contains the substring "will be removed" exactly once.
// The body is wrapped in do/while(0) so that `OLD_INTERFACE_WARNING_CHECK(x);`
// forms a single statement, which keeps it safe inside an unbraced if/else.
#define OLD_INTERFACE_WARNING_CHECK(func)       \
  do {                                          \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  } while (0)
205*da0073e9SAndroid Build Coastguard Worker 
206*da0073e9SAndroid Build Coastguard Worker struct MyOptimizerOptions
207*da0073e9SAndroid Build Coastguard Worker     : public OptimizerCloneableOptions<MyOptimizerOptions> {
MyOptimizerOptionsMyOptimizerOptions208*da0073e9SAndroid Build Coastguard Worker   MyOptimizerOptions(double lr = 1.0) : lr_(lr){};
209*da0073e9SAndroid Build Coastguard Worker   TORCH_ARG(double, lr) = 1.0;
210*da0073e9SAndroid Build Coastguard Worker };
211*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest,OldInterface)212*da0073e9SAndroid Build Coastguard Worker TEST(OptimTest, OldInterface) {
213*da0073e9SAndroid Build Coastguard Worker   struct MyOptimizer : Optimizer {
214*da0073e9SAndroid Build Coastguard Worker     using Optimizer::Optimizer;
215*da0073e9SAndroid Build Coastguard Worker     torch::Tensor step(LossClosure closure = nullptr) override {
216*da0073e9SAndroid Build Coastguard Worker       return {};
217*da0073e9SAndroid Build Coastguard Worker     }
218*da0073e9SAndroid Build Coastguard Worker     explicit MyOptimizer(
219*da0073e9SAndroid Build Coastguard Worker         std::vector<at::Tensor> params,
220*da0073e9SAndroid Build Coastguard Worker         MyOptimizerOptions defaults = {})
221*da0073e9SAndroid Build Coastguard Worker         : // NOLINTNEXTLINE(performance-move-const-arg)
222*da0073e9SAndroid Build Coastguard Worker           Optimizer(
223*da0073e9SAndroid Build Coastguard Worker               {std::move(OptimizerParamGroup(params))},
224*da0073e9SAndroid Build Coastguard Worker               std::make_unique<MyOptimizerOptions>(defaults)) {}
225*da0073e9SAndroid Build Coastguard Worker   };
226*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> parameters = {
227*da0073e9SAndroid Build Coastguard Worker       torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
228*da0073e9SAndroid Build Coastguard Worker   {
229*da0073e9SAndroid Build Coastguard Worker     MyOptimizer optimizer(parameters);
230*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
231*da0073e9SAndroid Build Coastguard Worker     size_t size;
232*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
233*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(size, parameters.size());
234*da0073e9SAndroid Build Coastguard Worker   }
235*da0073e9SAndroid Build Coastguard Worker   {
236*da0073e9SAndroid Build Coastguard Worker     std::vector<at::Tensor> params;
237*da0073e9SAndroid Build Coastguard Worker     MyOptimizer optimizer(params);
238*da0073e9SAndroid Build Coastguard Worker 
239*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
240*da0073e9SAndroid Build Coastguard Worker     size_t size;
241*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
242*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(size, 0);
243*da0073e9SAndroid Build Coastguard Worker 
244*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));
245*da0073e9SAndroid Build Coastguard Worker 
246*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
247*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(size, parameters.size());
248*da0073e9SAndroid Build Coastguard Worker 
249*da0073e9SAndroid Build Coastguard Worker     std::vector<torch::Tensor> params_;
250*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
251*da0073e9SAndroid Build Coastguard Worker     for (const auto p : c10::irange(size)) {
252*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(params_[p].allclose(parameters[p]));
253*da0073e9SAndroid Build Coastguard Worker     }
254*da0073e9SAndroid Build Coastguard Worker   }
255*da0073e9SAndroid Build Coastguard Worker   {
256*da0073e9SAndroid Build Coastguard Worker     Linear linear(3, 4);
257*da0073e9SAndroid Build Coastguard Worker     MyOptimizer optimizer(linear->parameters());
258*da0073e9SAndroid Build Coastguard Worker 
259*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
260*da0073e9SAndroid Build Coastguard Worker     size_t size;
261*da0073e9SAndroid Build Coastguard Worker     OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
262*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(size, linear->parameters().size());
263*da0073e9SAndroid Build Coastguard Worker   }
264*da0073e9SAndroid Build Coastguard Worker }
265*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_SGD) {
  // Nesterov-momentum SGD with a small weight decay must learn XOR.
  auto options =
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<SGD>(options));
}
270*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_LBFGS) {
  // LBFGS must learn XOR both without and with the strong-Wolfe line search.
  const auto plain = LBFGSOptions(1.0);
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(plain));
  const auto with_line_search =
      LBFGSOptions(1.0).line_search_fn("strong_wolfe");
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(with_line_search));
}
276*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_Adagrad) {
  // Adagrad with weight decay and learning-rate decay must learn XOR.
  auto options = AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3);
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(options));
}
281*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_RMSprop) {
  // Centered RMSprop must learn XOR.
  auto options = RMSpropOptions(0.1).centered(true);
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(options));
}
285*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  // RMSprop with momentum and a small weight decay must learn XOR.
  auto options = RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(options));
}
290*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_Adam) {
  // Adam with a small weight decay must learn XOR.
  auto options = AdamOptions(0.1).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<Adam>(options));
}
294*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  // Adam in AMSGrad mode with a small weight decay must learn XOR.
  auto options = AdamOptions(0.1).weight_decay(1e-6).amsgrad(true);
  ASSERT_TRUE(test_optimizer_xor<Adam>(options));
}
299*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  // Default Adam must reproduce the recorded baseline parameter values.
  const auto options = AdamOptions(1.0);
  check_exact_values<Adam>(options, expected_parameters::Adam());
}
303*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  // Adam with weight_decay = 1e-2 against its recorded baseline.
  const auto options = AdamOptions(1.0).weight_decay(1e-2);
  check_exact_values<Adam>(
      options, expected_parameters::Adam_with_weight_decay());
}
309*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  // Adam in AMSGrad mode with weight_decay = 1e-6 against its baseline.
  const auto options = AdamOptions(1.0).weight_decay(1e-6).amsgrad(true);
  check_exact_values<Adam>(
      options, expected_parameters::Adam_with_weight_decay_and_amsgrad());
}
315*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_AdamW) {
  // AdamW with its default settings (lr = 0.1) must learn XOR.
  auto options = AdamWOptions(0.1);
  ASSERT_TRUE(test_optimizer_xor<AdamW>(options));
}
319*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  // AdamW in AMSGrad mode must learn XOR.
  auto options = AdamWOptions(0.1).amsgrad(true);
  ASSERT_TRUE(test_optimizer_xor<AdamW>(options));
}
323*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  // Default AdamW must reproduce the recorded baseline parameter values.
  const auto options = AdamWOptions(1.0);
  check_exact_values<AdamW>(options, expected_parameters::AdamW());
}
327*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  // AdamW with weight decay explicitly disabled against its baseline.
  const auto options = AdamWOptions(1.0).weight_decay(0);
  check_exact_values<AdamW>(
      options, expected_parameters::AdamW_without_weight_decay());
}
333*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  // AdamW in AMSGrad mode against its recorded baseline.
  const auto options = AdamWOptions(1.0).amsgrad(true);
  check_exact_values<AdamW>(options, expected_parameters::AdamW_with_amsgrad());
}
339*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  // Default Adagrad must reproduce the recorded baseline parameter values.
  const auto options = AdagradOptions(1.0);
  check_exact_values<Adagrad>(options, expected_parameters::Adagrad());
}
344*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  // Adagrad with weight_decay = 1e-2 against its recorded baseline.
  const auto options = AdagradOptions(1.0).weight_decay(1e-2);
  check_exact_values<Adagrad>(
      options, expected_parameters::Adagrad_with_weight_decay());
}
350*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  // Adagrad with weight decay and learning-rate decay against its baseline.
  const auto options = AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3);
  check_exact_values<Adagrad>(
      options, expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}
356*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  // Default RMSprop (lr = 0.1) must reproduce the recorded baseline.
  const auto options = RMSpropOptions(0.1);
  check_exact_values<RMSprop>(options, expected_parameters::RMSprop());
}
361*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  // RMSprop with weight_decay = 1e-2 against its recorded baseline.
  const auto options = RMSpropOptions(0.1).weight_decay(1e-2);
  check_exact_values<RMSprop>(
      options, expected_parameters::RMSprop_with_weight_decay());
}
367*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  // Centered RMSprop with weight_decay = 1e-6 against its recorded baseline.
  const auto options = RMSpropOptions(0.1).weight_decay(1e-6).centered(true);
  check_exact_values<RMSprop>(
      options, expected_parameters::RMSprop_with_weight_decay_and_centered());
}
373*da0073e9SAndroid Build Coastguard Worker 
TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  // Centered RMSprop with momentum and weight_decay = 1e-6 against its
  // recorded baseline.
  const auto options =
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9);
  check_exact_values<RMSprop>(
      options,
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}
382*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_SGD) {
  // Plain SGD (lr = 0.1) must reproduce the recorded baseline.
  const auto options = SGDOptions(0.1);
  check_exact_values<SGD>(options, expected_parameters::SGD());
}
386*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  // SGD with weight_decay = 1e-2 against its recorded baseline.
  const auto options = SGDOptions(0.1).weight_decay(1e-2);
  check_exact_values<SGD>(options, expected_parameters::SGD_with_weight_decay());
}
392*da0073e9SAndroid Build Coastguard Worker 
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  // SGD with momentum 0.9 and weight_decay = 1e-2 against its baseline.
  const auto options = SGDOptions(0.1).weight_decay(1e-2).momentum(0.9);
  check_exact_values<SGD>(
      options, expected_parameters::SGD_with_weight_decay_and_momentum());
}
398*da0073e9SAndroid Build Coastguard Worker 
// SGD with a tiny weight decay and Nesterov momentum enabled must reproduce
// its stored baseline.
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  SGDOptions options(0.1);
  options.weight_decay(1e-6).momentum(0.9).nesterov(true);
  check_exact_values<SGD>(
      options,
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}
404*da0073e9SAndroid Build Coastguard Worker 
// LBFGS with lr = 1.0 and default settings must reproduce its stored baseline.
TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  LBFGSOptions options(1.0);
  check_exact_values<LBFGS>(options, expected_parameters::LBFGS());
}
408*da0073e9SAndroid Build Coastguard Worker 
// LBFGS using the "strong_wolfe" line search must reproduce its stored
// baseline.
TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  LBFGSOptions options(1.0);
  options.line_search_fn("strong_wolfe");
  check_exact_values<LBFGS>(
      options, expected_parameters::LBFGS_with_line_search());
}
414*da0073e9SAndroid Build Coastguard Worker 
// zero_grad() must clear gradients produced by a backward pass:
// - parameters start without a gradient,
// - after backward() each has a defined gradient with a positive sum,
// - after zero_grad() none of them has a defined gradient any more.
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  // The tensors in this list are handles to the module's parameters, so the
  // gradient state checked below is the one backward()/zero_grad() modify.
  const auto params = model->parameters();

  for (const auto& param : params) {
    ASSERT_FALSE(param.grad().defined());
  }

  model->forward(torch::ones({5, 2})).sum().backward();

  for (const auto& param : params) {
    ASSERT_TRUE(param.grad().defined());
    ASSERT_GT(param.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& param : params) {
    ASSERT_FALSE(param.grad().defined());
  }
}
440*da0073e9SAndroid Build Coastguard Worker 
// An optimizer constructed from a free-standing std::vector of tensors (not
// owned by any module) must update those tensors in place. With lr = 1 and
// all-ones gradients, one SGD step should subtract 1 from every element.
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};

  // Snapshot the starting values before any gradient or update is applied.
  std::vector<torch::Tensor> original_parameters;
  original_parameters.reserve(parameters.size());
  for (const auto& tensor : parameters) {
    original_parameters.push_back(tensor.clone());
  }

  for (auto& tensor : parameters) {
    tensor.mutable_grad() = torch::ones_like(tensor);
  }

  SGD optimizer(parameters, 1.0);
  optimizer.step();

  for (std::size_t idx = 0; idx < parameters.size(); ++idx) {
    ASSERT_TRUE(parameters[idx].allclose(original_parameters[idx] - 1.0));
  }
}
462*da0073e9SAndroid Build Coastguard Worker 
// Parameters added to an LBFGS optimizer after construction, through the old
// add_parameters() interface, must be usable by a subsequent step().
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};

  // Set all gradients to one so step() has something to work with.
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  // Start with no parameters and add them afterwards. The wrapping macro comes
  // from test/cpp/api/support.h; presumably it checks the old-interface
  // deprecation warning — confirm there.
  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  // The closure returns a constant loss. The test only requires that stepping
  // over the newly added parameters does not throw.
  ASSERT_NO_THROW(optimizer.step([]() { return torch::tensor(1); }));
}
481*da0073e9SAndroid Build Coastguard Worker 
482*da0073e9SAndroid Build Coastguard Worker // Check whether the learning rate of the parameter groups in the optimizer are
483*da0073e9SAndroid Build Coastguard Worker // the same as the expected learning rates given in the epoch:learning rate map
check_lr_change(Optimizer & optimizer,LRScheduler & lr_scheduler,std::map<unsigned,double> expected_epoch_lrs)484*da0073e9SAndroid Build Coastguard Worker void check_lr_change(
485*da0073e9SAndroid Build Coastguard Worker     Optimizer& optimizer,
486*da0073e9SAndroid Build Coastguard Worker     LRScheduler& lr_scheduler,
487*da0073e9SAndroid Build Coastguard Worker     std::map<unsigned, double> expected_epoch_lrs) {
488*da0073e9SAndroid Build Coastguard Worker   // Find maximum epoch in map
489*da0073e9SAndroid Build Coastguard Worker   unsigned kIterations = std::max_element(
490*da0073e9SAndroid Build Coastguard Worker                              expected_epoch_lrs.begin(),
491*da0073e9SAndroid Build Coastguard Worker                              expected_epoch_lrs.end(),
492*da0073e9SAndroid Build Coastguard Worker                              [](const std::pair<unsigned, double>& a,
493*da0073e9SAndroid Build Coastguard Worker                                 const std::pair<unsigned, double>& b) -> bool {
494*da0073e9SAndroid Build Coastguard Worker                                return a.second > b.second;
495*da0073e9SAndroid Build Coastguard Worker                              })
496*da0073e9SAndroid Build Coastguard Worker                              ->first;
497*da0073e9SAndroid Build Coastguard Worker 
498*da0073e9SAndroid Build Coastguard Worker   for (unsigned i = 0; i <= kIterations; i++) {
499*da0073e9SAndroid Build Coastguard Worker     const auto epoch_iter = expected_epoch_lrs.find(i);
500*da0073e9SAndroid Build Coastguard Worker     if (epoch_iter != expected_epoch_lrs.end()) {
501*da0073e9SAndroid Build Coastguard Worker       // Compare the similarity of the two floating point learning rates
502*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(
503*da0073e9SAndroid Build Coastguard Worker           fabs(
504*da0073e9SAndroid Build Coastguard Worker               epoch_iter->second -
505*da0073e9SAndroid Build Coastguard Worker               optimizer.param_groups()[0].options().get_lr()) <
506*da0073e9SAndroid Build Coastguard Worker           std::numeric_limits<double>::epsilon());
507*da0073e9SAndroid Build Coastguard Worker     }
508*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
509*da0073e9SAndroid Build Coastguard Worker     lr_scheduler.step();
510*da0073e9SAndroid Build Coastguard Worker   }
511*da0073e9SAndroid Build Coastguard Worker }
512*da0073e9SAndroid Build Coastguard Worker 
513*da0073e9SAndroid Build Coastguard Worker // Very similar to check_lr_change, but for ReduceLROnPlateauScheduler
514*da0073e9SAndroid Build Coastguard Worker // which does not inherit from LRScheduler and requires a metrics
515*da0073e9SAndroid Build Coastguard Worker // input to step().
check_lr_change_for_reduce_on_plateau(Optimizer & optimizer,ReduceLROnPlateauScheduler & lr_scheduler,std::map<unsigned,double> expected_epoch_lrs)516*da0073e9SAndroid Build Coastguard Worker void check_lr_change_for_reduce_on_plateau(
517*da0073e9SAndroid Build Coastguard Worker     Optimizer& optimizer,
518*da0073e9SAndroid Build Coastguard Worker     ReduceLROnPlateauScheduler& lr_scheduler,
519*da0073e9SAndroid Build Coastguard Worker     std::map<unsigned, double> expected_epoch_lrs) {
520*da0073e9SAndroid Build Coastguard Worker   // Find maximum epoch in map
521*da0073e9SAndroid Build Coastguard Worker   unsigned kIterations = std::max_element(
522*da0073e9SAndroid Build Coastguard Worker                              expected_epoch_lrs.begin(),
523*da0073e9SAndroid Build Coastguard Worker                              expected_epoch_lrs.end(),
524*da0073e9SAndroid Build Coastguard Worker                              [](const std::pair<unsigned, double>& a,
525*da0073e9SAndroid Build Coastguard Worker                                 const std::pair<unsigned, double>& b) -> bool {
526*da0073e9SAndroid Build Coastguard Worker                                return a.second > b.second;
527*da0073e9SAndroid Build Coastguard Worker                              })
528*da0073e9SAndroid Build Coastguard Worker                              ->first;
529*da0073e9SAndroid Build Coastguard Worker 
530*da0073e9SAndroid Build Coastguard Worker   for (unsigned i = 0; i <= kIterations; i++) {
531*da0073e9SAndroid Build Coastguard Worker     const auto epoch_iter = expected_epoch_lrs.find(i);
532*da0073e9SAndroid Build Coastguard Worker     if (epoch_iter != expected_epoch_lrs.end()) {
533*da0073e9SAndroid Build Coastguard Worker       // Compare the similarity of the two floating point learning rates
534*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(
535*da0073e9SAndroid Build Coastguard Worker           fabs(
536*da0073e9SAndroid Build Coastguard Worker               epoch_iter->second -
537*da0073e9SAndroid Build Coastguard Worker               optimizer.param_groups()[0].options().get_lr()) <
538*da0073e9SAndroid Build Coastguard Worker           std::numeric_limits<double>::epsilon());
539*da0073e9SAndroid Build Coastguard Worker     }
540*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
541*da0073e9SAndroid Build Coastguard Worker     lr_scheduler.step(5.0);
542*da0073e9SAndroid Build Coastguard Worker   }
543*da0073e9SAndroid Build Coastguard Worker }
544*da0073e9SAndroid Build Coastguard Worker 
// StepLR multiplies the learning rate by gamma every step_size epochs.
// Epoch 1 should therefore still see the initial rate (1e-3), and epoch 25,
// past the first boundary at 20, the halved one (5e-4).
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  Adam optimizer({parameters}, AdamOptions().lr(1e-3));

  constexpr unsigned kStepSize = 20;
  constexpr double kGamma = 0.5;
  StepLR step_lr_scheduler(optimizer, kStepSize, kGamma);

  const std::map<unsigned, double> expected_epoch_lrs{{1, 1e-3}, {25, 5e-4}};
  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}
558*da0073e9SAndroid Build Coastguard Worker 
// Reduce-on-plateau in "min" mode with factor 0.5 and patience 20.
// check_lr_change_for_reduce_on_plateau feeds a constant metric, so the
// learning rate should still be 1e-3 at epoch 1 and have halved to 5e-4 by
// epoch 25.
TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  Adam optimizer({parameters}, AdamOptions().lr(1e-3));

  constexpr float kFactor = 0.5;
  constexpr int kPatience = 20;
  const auto mode = ReduceLROnPlateauScheduler::SchedulerMode::min;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer, mode, kFactor, kPatience);

  const std::map<unsigned, double> expected_epoch_lrs{{1, 1e-3}, {25, 5e-4}};
  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}
576