xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/optim/schedulers/reduce_on_plateau_scheduler.h>
2 
3 #include <iomanip>
4 
5 namespace torch {
6 namespace optim {
7 
ReduceLROnPlateauScheduler(Optimizer & optimizer,SchedulerMode mode,float factor,int patience,double threshold,ThresholdMode threshold_mode,int cooldown,const std::vector<float> & min_lr,double eps,bool verbose)8 ReduceLROnPlateauScheduler::ReduceLROnPlateauScheduler(
9     Optimizer& optimizer,
10     SchedulerMode mode,
11     float factor,
12     int patience,
13     double threshold,
14     ThresholdMode threshold_mode,
15     int cooldown,
16     const std::vector<float>& min_lr,
17     double eps,
18     bool verbose)
19     : optimizer(optimizer) {
20   if (min_lr.empty()) {
21     this->min_lrs = std::vector<float>(optimizer.param_groups().size());
22   } else {
23     // Check if number of learning rates is equal to the number of parameters
24     // groups in the optimizer
25     TORCH_CHECK(
26         min_lr.size() == optimizer.param_groups().size(),
27         "Number of learning rates not equal to the number of param groups\n",
28         "Number of learning rates given: ",
29         min_lr.size(),
30         "\nNumber of param groups: ",
31         optimizer.param_groups().size());
32     this->min_lrs = min_lr;
33   }
34 
35   TORCH_CHECK(factor < 1.0, "Factor should be < 1.0.");
36   this->factor = factor;
37   this->patience = patience;
38   this->cooldown = cooldown;
39   this->eps = eps;
40   this->verbose = verbose;
41 
42   init_is_better(mode, threshold, threshold_mode);
43   reset();
44 }
45 
step(float metrics)46 void ReduceLROnPlateauScheduler::step(float metrics) {
47   last_epoch++;
48 
49   if (is_better(metrics)) {
50     best = metrics;
51     num_bad_epochs = 0;
52   } else {
53     num_bad_epochs++;
54   }
55 
56   if (in_cooldown()) {
57     cooldown_counter--;
58     num_bad_epochs = 0;
59   }
60 
61   if (num_bad_epochs > patience) {
62     reduce_lr(last_epoch);
63     cooldown_counter = cooldown;
64     num_bad_epochs = 0;
65   }
66 }
67 
reduce_lr(int epoch)68 void ReduceLROnPlateauScheduler::reduce_lr(int epoch) {
69   for (std::size_t i = 0; i < optimizer.param_groups().size(); i++) {
70     auto old_lr = optimizer.param_groups()[i].options().get_lr();
71     auto new_lr = std::fmax(old_lr * factor, min_lrs[i]);
72     if (old_lr - new_lr > eps) {
73       optimizer.param_groups()[i].options().set_lr(new_lr);
74       if (verbose) {
75         std::cout << std::setprecision(4) << "Epoch " << epoch
76                   << ": reducing learning rate of group " << i << " to "
77                   << new_lr << std::endl;
78       }
79     }
80   }
81 }
82 
reset()83 void ReduceLROnPlateauScheduler::reset() {
84   this->cooldown_counter = 0;
85   this->num_bad_epochs = 0;
86   this->last_epoch = 0;
87   this->best = mode_worse;
88 }
89 
in_cooldown()90 bool ReduceLROnPlateauScheduler::in_cooldown() {
91   return cooldown_counter > 0;
92 }
93 
is_better(float a)94 bool ReduceLROnPlateauScheduler::is_better(float a) {
95   if (mode == min && threshold_mode == rel) {
96     auto rel_epsilon = 1.0 - threshold;
97     return a < best * rel_epsilon;
98   } else if (mode == min && threshold_mode == abs) {
99     return a < best - threshold;
100   } else if (mode == max && threshold_mode == rel) {
101     auto rel_epsilon = 1.0 + threshold;
102     return a > best * rel_epsilon;
103   } else {
104     return a > best * threshold;
105   }
106 }
107 
init_is_better(SchedulerMode mode,double threshold,ThresholdMode threshold_mode)108 void ReduceLROnPlateauScheduler::init_is_better(
109     SchedulerMode mode,
110     double threshold,
111     ThresholdMode threshold_mode) {
112   if (mode == min) {
113     mode_worse = std::numeric_limits<float>::max();
114   } else {
115     mode_worse = std::numeric_limits<float>::min();
116   }
117 
118   this->mode = mode;
119   this->threshold_mode = threshold_mode;
120   this->threshold = threshold;
121 }
122 } // namespace optim
123 } // namespace torch
124