#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

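// Trains a small 2-8-1 sigmoid MLP on randomly generated XOR data with the
// given optimizer. Returns true if the exponentially smoothed BCE loss drops
// below 0.1 within kMaximumNumberOfEpochs steps, false otherwise.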
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}

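// Copies `new_tensor` into the named parameter `name`, flattening the
// destination first. requires_grad is temporarily disabled so the copy itself
// is not recorded by autograd.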
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

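// Runs kIterations optimizer steps on a tiny 2-3-1 sigmoid model with fixed
// initial weights and a fixed input, and every kSampleEvery steps compares
// the parameters against the baseline values from optim_baseline.h
// (presumably generated by the Python reference optimizers).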
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

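// Exercises the non-const and const accessors of the Optimizer base class:
// defaults(), param_groups(), state(), and the duplicate-parameter check in
// add_param_group().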
TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

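// Runs `func` while capturing warnings and asserts that exactly one
// deprecation warning (containing "will be removed") was emitted.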
#define OLD_INTERFACE_WARNING_CHECK(func)       \
  {                                             \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  }

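// Minimal Options type (and, inside the test, a minimal Optimizer subclass)
// used to exercise the deprecated size()/add_parameters()/parameters()
// interface.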
struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
  TORCH_ARG(double, lr) = 1.0;
};

TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : // NOLINTNEXTLINE(performance-move-const-arg)
          Optimizer(
              {std::move(OptimizerParamGroup(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

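// The ProducesPyTorchValues_* tests below run check_exact_values() against
// the expected_parameters baselines from optim_baseline.h for each
// optimizer/option combination.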
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}

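// zero_grad() is expected to reset gradients back to undefined tensors, which
// is what the final loop below asserts.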
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

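// With lr = 1.0 and every gradient set to ones, a single SGD step should
// subtract exactly 1.0 from each parameter, even though the tensors were
// never wrapped in a module.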
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

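// Only checks that adding parameters through the deprecated add_parameters()
// interface and then taking one LBFGS step does not throw.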
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // REQUIRE this doesn't throw
}

// Check that the learning rate of the optimizer's first parameter group
// matches the expected learning rate given in the epoch -> learning rate map.
void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the last epoch to run. Because the schedulers used in these tests
  // only ever decrease the learning rate, the map entry with the smallest
  // learning rate is also the one with the largest epoch.
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Check that the current learning rate matches the expected one,
      // within double-precision epsilon
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}

// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler,
// which does not inherit from LRScheduler and requires a metrics
// input to step().
void check_lr_change_for_reduce_on_plateau(
    Optimizer& optimizer,
    ReduceLROnPlateauScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the last epoch to run (the entry with the smallest learning rate;
  // see check_lr_change above)
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Check that the current learning rate matches the expected one,
      // within double-precision epsilon
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step(5.0);
  }
}

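// StepLR with step_size = 20 and gamma = 0.5: the learning rate stays at 1e-3
// for epochs 0-19 and is halved to 5e-4 from epoch 20 onward, so checking at
// epochs 1 and 25 covers both regimes.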
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}

TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
  const float factor = 0.5;
  const int patience = 20;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer,
      ReduceLROnPlateauScheduler::SchedulerMode::min,
      factor,
      patience);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}