#include <torch/optim/schedulers/reduce_on_plateau_scheduler.h>

#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
4
5 namespace torch {
6 namespace optim {
7
ReduceLROnPlateauScheduler(Optimizer & optimizer,SchedulerMode mode,float factor,int patience,double threshold,ThresholdMode threshold_mode,int cooldown,const std::vector<float> & min_lr,double eps,bool verbose)8 ReduceLROnPlateauScheduler::ReduceLROnPlateauScheduler(
9 Optimizer& optimizer,
10 SchedulerMode mode,
11 float factor,
12 int patience,
13 double threshold,
14 ThresholdMode threshold_mode,
15 int cooldown,
16 const std::vector<float>& min_lr,
17 double eps,
18 bool verbose)
19 : optimizer(optimizer) {
20 if (min_lr.empty()) {
21 this->min_lrs = std::vector<float>(optimizer.param_groups().size());
22 } else {
23 // Check if number of learning rates is equal to the number of parameters
24 // groups in the optimizer
25 TORCH_CHECK(
26 min_lr.size() == optimizer.param_groups().size(),
27 "Number of learning rates not equal to the number of param groups\n",
28 "Number of learning rates given: ",
29 min_lr.size(),
30 "\nNumber of param groups: ",
31 optimizer.param_groups().size());
32 this->min_lrs = min_lr;
33 }
34
35 TORCH_CHECK(factor < 1.0, "Factor should be < 1.0.");
36 this->factor = factor;
37 this->patience = patience;
38 this->cooldown = cooldown;
39 this->eps = eps;
40 this->verbose = verbose;
41
42 init_is_better(mode, threshold, threshold_mode);
43 reset();
44 }
45
step(float metrics)46 void ReduceLROnPlateauScheduler::step(float metrics) {
47 last_epoch++;
48
49 if (is_better(metrics)) {
50 best = metrics;
51 num_bad_epochs = 0;
52 } else {
53 num_bad_epochs++;
54 }
55
56 if (in_cooldown()) {
57 cooldown_counter--;
58 num_bad_epochs = 0;
59 }
60
61 if (num_bad_epochs > patience) {
62 reduce_lr(last_epoch);
63 cooldown_counter = cooldown;
64 num_bad_epochs = 0;
65 }
66 }
67
reduce_lr(int epoch)68 void ReduceLROnPlateauScheduler::reduce_lr(int epoch) {
69 for (std::size_t i = 0; i < optimizer.param_groups().size(); i++) {
70 auto old_lr = optimizer.param_groups()[i].options().get_lr();
71 auto new_lr = std::fmax(old_lr * factor, min_lrs[i]);
72 if (old_lr - new_lr > eps) {
73 optimizer.param_groups()[i].options().set_lr(new_lr);
74 if (verbose) {
75 std::cout << std::setprecision(4) << "Epoch " << epoch
76 << ": reducing learning rate of group " << i << " to "
77 << new_lr << std::endl;
78 }
79 }
80 }
81 }
82
reset()83 void ReduceLROnPlateauScheduler::reset() {
84 this->cooldown_counter = 0;
85 this->num_bad_epochs = 0;
86 this->last_epoch = 0;
87 this->best = mode_worse;
88 }
89
in_cooldown()90 bool ReduceLROnPlateauScheduler::in_cooldown() {
91 return cooldown_counter > 0;
92 }
93
is_better(float a)94 bool ReduceLROnPlateauScheduler::is_better(float a) {
95 if (mode == min && threshold_mode == rel) {
96 auto rel_epsilon = 1.0 - threshold;
97 return a < best * rel_epsilon;
98 } else if (mode == min && threshold_mode == abs) {
99 return a < best - threshold;
100 } else if (mode == max && threshold_mode == rel) {
101 auto rel_epsilon = 1.0 + threshold;
102 return a > best * rel_epsilon;
103 } else {
104 return a > best * threshold;
105 }
106 }
107
init_is_better(SchedulerMode mode,double threshold,ThresholdMode threshold_mode)108 void ReduceLROnPlateauScheduler::init_is_better(
109 SchedulerMode mode,
110 double threshold,
111 ThresholdMode threshold_mode) {
112 if (mode == min) {
113 mode_worse = std::numeric_limits<float>::max();
114 } else {
115 mode_worse = std::numeric_limits<float>::min();
116 }
117
118 this->mode = mode;
119 this->threshold_mode = threshold_mode;
120 this->threshold = threshold;
121 }
122 } // namespace optim
123 } // namespace torch
124