1 #include <gtest/gtest.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5
6 #include <test/cpp/api/optim_baseline.h>
7 #include <test/cpp/api/support.h>
8
9 #include <cmath>
10 #include <cstdlib>
11 #include <functional>
12 #include <iostream>
13 #include <memory>
14 #include <random>
15 #include <vector>
16
17 using namespace torch::nn;
18 using namespace torch::optim;
19
20 template <typename OptimizerClass, typename Options>
test_optimizer_xor(Options options)21 bool test_optimizer_xor(Options options) {
22 torch::manual_seed(0);
23
24 Sequential model(
25 Linear(2, 8),
26 Functional(torch::sigmoid),
27 Linear(8, 1),
28 Functional(torch::sigmoid));
29
30 const int64_t kBatchSize = 200;
31 const int64_t kMaximumNumberOfEpochs = 3000;
32
33 OptimizerClass optimizer(model->parameters(), options);
34
35 float running_loss = 1;
36 int epoch = 0;
37 while (running_loss > 0.1) {
38 auto inputs = torch::empty({kBatchSize, 2});
39 auto labels = torch::empty({kBatchSize});
40 for (const auto i : c10::irange(kBatchSize)) {
41 inputs[i] = torch::randint(2, {2}, torch::kInt64);
42 labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
43 }
44
45 inputs.set_requires_grad(true);
46
47 auto step = [&](OptimizerClass& optimizer,
48 Sequential model,
49 torch::Tensor inputs,
50 torch::Tensor labels) {
51 auto closure = [&]() {
52 optimizer.zero_grad();
53 auto x = model->forward(inputs);
54 auto loss = torch::binary_cross_entropy(x, labels);
55 loss.backward();
56 return loss;
57 };
58 return optimizer.step(closure);
59 };
60
61 torch::Tensor loss = step(optimizer, model, inputs, labels);
62
63 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
64 running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
65 if (epoch > kMaximumNumberOfEpochs) {
66 std::cout << "Loss is too high after epoch " << epoch << ": "
67 << running_loss << std::endl;
68 return false;
69 }
70 epoch++;
71 }
72 return true;
73 }
74
75 template <typename Parameters>
assign_parameter(const Parameters & parameters,const char * name,torch::Tensor new_tensor)76 void assign_parameter(
77 const Parameters& parameters,
78 const char* name,
79 torch::Tensor new_tensor) {
80 auto parameter = parameters[name];
81 parameter.set_requires_grad(false);
82 parameter.flatten().copy_(new_tensor);
83 parameter.set_requires_grad(true);
84 }
85
86 template <typename OptimizerClass, typename Options>
check_exact_values(Options options,std::vector<std::vector<torch::Tensor>> expected_parameters)87 void check_exact_values(
88 Options options,
89 std::vector<std::vector<torch::Tensor>> expected_parameters) {
90 const size_t kIterations = 1001;
91 const size_t kSampleEvery = 100;
92
93 torch::manual_seed(0);
94
95 Sequential model(
96 Linear(2, 3),
97 Functional(torch::sigmoid),
98 Linear(3, 1),
99 Functional(torch::sigmoid));
100
101 model->to(torch::kFloat64);
102
103 // Use exact input values because matching random values is hard.
104 auto parameters = model->named_parameters();
105 assign_parameter(
106 parameters,
107 "0.weight",
108 torch::tensor(
109 {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
110 torch::kFloat64));
111 assign_parameter(
112 parameters,
113 "0.bias",
114 torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
115 assign_parameter(
116 parameters,
117 "2.weight",
118 torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
119 assign_parameter(
120 parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));
121
122 auto optimizer = OptimizerClass(parameters.values(), options);
123 torch::Tensor input =
124 torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
125 .reshape({3, 2});
126
127 for (const auto i : c10::irange(kIterations)) {
128 optimizer.zero_grad();
129 auto output = model->forward(input);
130 auto loss = output.sum();
131 loss.backward();
132
133 auto closure = []() { return torch::tensor({10}); };
134 optimizer.step(closure);
135
136 if (i % kSampleEvery == 0) {
137 ASSERT_TRUE(
138 expected_parameters.at(i / kSampleEvery).size() == parameters.size());
139 for (const auto p : c10::irange(parameters.size())) {
140 ASSERT_TRUE(parameters[p]->defined());
141 // Always compare using double dtype, regardless of the original dtype
142 // of the tensors
143 auto computed = parameters[p]->flatten().to(torch::kFloat64);
144 auto expected =
145 expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
146 if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
147 std::cout << "Iteration " << i << ": " << computed
148 << " != " << expected << " (parameter " << p << ")"
149 << std::endl;
150 ASSERT_TRUE(false);
151 }
152 }
153 }
154 }
155 }
156
// Exercises the mutable and const accessors of Optimizer: defaults(),
// param_groups(), add_param_group() duplicate detection, and state().
TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  // The group appended through the mutable reference must hold exactly the
  // tensors it was built from. The result of torch::equal must be asserted,
  // otherwise this loop checks nothing.
  ASSERT_EQ(params_1.size(), params.size());
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}
195
// Evaluates `func` while capturing warnings and asserts that the captured
// warning text contains exactly one occurrence of "will be removed".
// Expands to a brace-enclosed block. ASSERT_EQ returns from the enclosing
// function on failure.
#define OLD_INTERFACE_WARNING_CHECK(func)              \
  {                                                    \
    torch::test::WarningCapture warning_capture;       \
    func;                                              \
    const auto removal_warning_count =                 \
        torch::test::count_substr_occurrences(         \
            warning_capture.str(), "will be removed"); \
    ASSERT_EQ(removal_warning_count, 1);               \
  }
205
206 struct MyOptimizerOptions
207 : public OptimizerCloneableOptions<MyOptimizerOptions> {
MyOptimizerOptionsMyOptimizerOptions208 MyOptimizerOptions(double lr = 1.0) : lr_(lr){};
209 TORCH_ARG(double, lr) = 1.0;
210 };
211
// Checks that the old-interface accessors (size(), add_parameters(),
// parameters()) still work and each emits exactly one "will be removed"
// warning.
TEST(OptimTest, OldInterface) {
  // Minimal optimizer whose step() does nothing and returns an empty tensor;
  // it only exists to reach the base-class accessors.
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure /*closure*/ = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : Optimizer(
              {OptimizerParamGroup(std::move(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0u);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}
265
// XOR convergence smoke tests: each optimizer configuration must drive the
// smoothed training loss of test_optimizer_xor below its threshold in time.

TEST(OptimTest, XORConvergence_SGD) {
  // Nesterov momentum with a tiny weight decay.
  const auto options =
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<SGD>(options));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  // Both the default update and the strong-Wolfe line search must converge.
  const auto default_options = LBFGSOptions(1.0);
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(default_options));

  const auto line_search_options =
      LBFGSOptions(1.0).line_search_fn("strong_wolfe");
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(line_search_options));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  const auto options = AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3);
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(options));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  const auto options = RMSpropOptions(0.1).centered(true);
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(options));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  const auto options = RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(options));
}

TEST(OptimTest, XORConvergence_Adam) {
  const auto options = AdamOptions(0.1).weight_decay(1e-6);
  ASSERT_TRUE(test_optimizer_xor<Adam>(options));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  const auto options = AdamOptions(0.1).weight_decay(1e-6).amsgrad(true);
  ASSERT_TRUE(test_optimizer_xor<Adam>(options));
}
299
// Adam must reproduce the reference parameter trajectories stored in
// expected_parameters for each option combination.

TEST(OptimTest, ProducesPyTorchValues_Adam) {
  const auto options = AdamOptions(1.0);
  const auto expected = expected_parameters::Adam();
  check_exact_values<Adam>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  const auto options = AdamOptions(1.0).weight_decay(1e-2);
  const auto expected = expected_parameters::Adam_with_weight_decay();
  check_exact_values<Adam>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  const auto options = AdamOptions(1.0).weight_decay(1e-6).amsgrad(true);
  const auto expected =
      expected_parameters::Adam_with_weight_decay_and_amsgrad();
  check_exact_values<Adam>(options, expected);
}
315
// XOR convergence smoke tests for AdamW, with and without AMSGrad.

TEST(OptimTest, XORConvergence_AdamW) {
  const auto options = AdamWOptions(0.1);
  ASSERT_TRUE(test_optimizer_xor<AdamW>(options));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  const auto options = AdamWOptions(0.1).amsgrad(true);
  ASSERT_TRUE(test_optimizer_xor<AdamW>(options));
}
323
// AdamW must reproduce the reference parameter trajectories stored in
// expected_parameters for each option combination.

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  const auto options = AdamWOptions(1.0);
  const auto expected = expected_parameters::AdamW();
  check_exact_values<AdamW>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  const auto options = AdamWOptions(1.0).weight_decay(0);
  const auto expected = expected_parameters::AdamW_without_weight_decay();
  check_exact_values<AdamW>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  const auto options = AdamWOptions(1.0).amsgrad(true);
  const auto expected = expected_parameters::AdamW_with_amsgrad();
  check_exact_values<AdamW>(options, expected);
}
339
// Adagrad and RMSprop must reproduce the reference parameter trajectories
// stored in expected_parameters for each option combination.

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  const auto options = AdagradOptions(1.0);
  const auto expected = expected_parameters::Adagrad();
  check_exact_values<Adagrad>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  const auto options = AdagradOptions(1.0).weight_decay(1e-2);
  const auto expected = expected_parameters::Adagrad_with_weight_decay();
  check_exact_values<Adagrad>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  const auto options = AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3);
  const auto expected =
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay();
  check_exact_values<Adagrad>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  const auto options = RMSpropOptions(0.1);
  const auto expected = expected_parameters::RMSprop();
  check_exact_values<RMSprop>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  const auto options = RMSpropOptions(0.1).weight_decay(1e-2);
  const auto expected = expected_parameters::RMSprop_with_weight_decay();
  check_exact_values<RMSprop>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  const auto options = RMSpropOptions(0.1).weight_decay(1e-6).centered(true);
  const auto expected =
      expected_parameters::RMSprop_with_weight_decay_and_centered();
  check_exact_values<RMSprop>(options, expected);
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  const auto options =
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9);
  const auto expected = expected_parameters::
      RMSprop_with_weight_decay_and_centered_and_momentum();
  check_exact_values<RMSprop>(options, expected);
}
382
// SGD and LBFGS must reproduce the reference parameter trajectories stored in
// expected_parameters for each option combination.

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  const auto options = SGDOptions(0.1);
  const auto expected = expected_parameters::SGD();
  check_exact_values<SGD>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  const auto options = SGDOptions(0.1).weight_decay(1e-2);
  const auto expected = expected_parameters::SGD_with_weight_decay();
  check_exact_values<SGD>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  const auto options = SGDOptions(0.1).weight_decay(1e-2).momentum(0.9);
  const auto expected =
      expected_parameters::SGD_with_weight_decay_and_momentum();
  check_exact_values<SGD>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  const auto options =
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true);
  const auto expected =
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum();
  check_exact_values<SGD>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  const auto options = LBFGSOptions(1.0);
  const auto expected = expected_parameters::LBFGS();
  check_exact_values<LBFGS>(options, expected);
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  const auto options = LBFGSOptions(1.0).line_search_fn("strong_wolfe");
  const auto expected = expected_parameters::LBFGS_with_line_search();
  check_exact_values<LBFGS>(options, expected);
}
414
// Verifies the gradient lifecycle around Optimizer::zero_grad(): no gradients
// before backward, positive gradients after it, and undefined gradients again
// after zero_grad().
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);
  // Tensor handles share their gradient with the module's parameters, so one
  // snapshot of the list observes every later change.
  const std::vector<torch::Tensor> parameters = model->parameters();

  for (const auto i : c10::irange(parameters.size())) {
    ASSERT_FALSE(parameters[i].grad().defined());
  }

  // Backpropagate the sum of the outputs for an all-ones batch of 5 inputs.
  auto activations = model->forward(torch::ones({5, 2}));
  auto total = activations.sum();
  total.backward();

  for (const auto i : c10::irange(parameters.size())) {
    const auto& gradient = parameters[i].grad();
    ASSERT_TRUE(gradient.defined());
    ASSERT_GT(gradient.sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto i : c10::irange(parameters.size())) {
    ASSERT_FALSE(parameters[i].grad().defined());
  }
}
440
// An optimizer built from a plain std::vector of tensors must update those
// tensors in place: with lr = 1 and all-ones gradients, one SGD step subtracts
// exactly 1 from every element.
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  // Braced-list elements are evaluated left to right, which fixes the RNG
  // draw order.
  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};

  std::vector<torch::Tensor> original_parameters;
  original_parameters.reserve(parameters.size());
  for (const auto& parameter : parameters) {
    original_parameters.push_back(parameter.clone());
  }

  // Set all gradients to one.
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  for (const auto i : c10::irange(parameters.size())) {
    ASSERT_TRUE(parameters[i].allclose(original_parameters[i] - 1.0));
  }
}
462
// Parameters added through the old add_parameters() interface must be usable
// by LBFGS: adding them emits exactly one "will be removed" warning, and a
// subsequent step() must not throw.
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  // Start with no parameters and register them afterwards.
  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  // The closure returns a constant loss; only the absence of an exception is
  // being checked here.
  ASSERT_NO_THROW(optimizer.step([]() { return torch::tensor(1); }));
}
481
482 // Check whether the learning rate of the parameter groups in the optimizer are
483 // the same as the expected learning rates given in the epoch:learning rate map
check_lr_change(Optimizer & optimizer,LRScheduler & lr_scheduler,std::map<unsigned,double> expected_epoch_lrs)484 void check_lr_change(
485 Optimizer& optimizer,
486 LRScheduler& lr_scheduler,
487 std::map<unsigned, double> expected_epoch_lrs) {
488 // Find maximum epoch in map
489 unsigned kIterations = std::max_element(
490 expected_epoch_lrs.begin(),
491 expected_epoch_lrs.end(),
492 [](const std::pair<unsigned, double>& a,
493 const std::pair<unsigned, double>& b) -> bool {
494 return a.second > b.second;
495 })
496 ->first;
497
498 for (unsigned i = 0; i <= kIterations; i++) {
499 const auto epoch_iter = expected_epoch_lrs.find(i);
500 if (epoch_iter != expected_epoch_lrs.end()) {
501 // Compare the similarity of the two floating point learning rates
502 ASSERT_TRUE(
503 fabs(
504 epoch_iter->second -
505 optimizer.param_groups()[0].options().get_lr()) <
506 std::numeric_limits<double>::epsilon());
507 }
508 optimizer.step();
509 lr_scheduler.step();
510 }
511 }
512
513 // Very similar to check_lr_change, but for ReduceLROnPlateauScheduler
514 // which does not inherit from LRScheduler and requires a metrics
515 // input to step().
check_lr_change_for_reduce_on_plateau(Optimizer & optimizer,ReduceLROnPlateauScheduler & lr_scheduler,std::map<unsigned,double> expected_epoch_lrs)516 void check_lr_change_for_reduce_on_plateau(
517 Optimizer& optimizer,
518 ReduceLROnPlateauScheduler& lr_scheduler,
519 std::map<unsigned, double> expected_epoch_lrs) {
520 // Find maximum epoch in map
521 unsigned kIterations = std::max_element(
522 expected_epoch_lrs.begin(),
523 expected_epoch_lrs.end(),
524 [](const std::pair<unsigned, double>& a,
525 const std::pair<unsigned, double>& b) -> bool {
526 return a.second > b.second;
527 })
528 ->first;
529
530 for (unsigned i = 0; i <= kIterations; i++) {
531 const auto epoch_iter = expected_epoch_lrs.find(i);
532 if (epoch_iter != expected_epoch_lrs.end()) {
533 // Compare the similarity of the two floating point learning rates
534 ASSERT_TRUE(
535 fabs(
536 epoch_iter->second -
537 optimizer.param_groups()[0].options().get_lr()) <
538 std::numeric_limits<double>::epsilon());
539 }
540 optimizer.step();
541 lr_scheduler.step(5.0);
542 }
543 }
544
// StepLR with step_size = 20 and gamma = 0.5 must keep Adam's initial learning
// rate before epoch 20 and halve it afterwards.
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // Epoch 1 is still in the first window; epoch 25 is past the first decay.
  const std::map<unsigned, double> expected_epoch_lrs{{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}
558
// ReduceLROnPlateau in `min` mode, fed a metric that never improves, must
// scale Adam's learning rate by `factor` once the patience window runs out.
TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const float factor = 0.5;
  const int patience = 20;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer,
      ReduceLROnPlateauScheduler::SchedulerMode::min,
      factor,
      patience);

  // The helper reports the same metric every epoch, so after roughly
  // `patience` epochs without improvement the rate should be halved. Epoch 25
  // is checked to leave some margin past that point.
  const std::map<unsigned, double> expected_epoch_lrs{{1, 1e-3}, {25, 5e-4}};

  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}
576