xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/optim/optimizer.h>
#include <torch/optim/schedulers/lr_scheduler.h>

#include <torch/csrc/Export.h>

#include <cmath>
#include <iostream>
#include <string>
#include <vector>
14 namespace torch {
15 namespace optim {
16 
17 class TORCH_API ReduceLROnPlateauScheduler {
18  public:
19   enum SchedulerMode { min, max };
20   enum ThresholdMode { rel, abs };
21   ReduceLROnPlateauScheduler(
22       Optimizer& optimizer,
23       SchedulerMode mode = min,
24       float factor = 0.1,
25       int patience = 10,
26       double threshold = 1e-4,
27       ThresholdMode threshold_mode = rel,
28       int cooldown = 0,
29       const std::vector<float>& min_lr = std::vector<float>(),
30       double eps = 1e-8,
31       bool verbose = false);
32 
33   virtual ~ReduceLROnPlateauScheduler() = default;
34 
35   void step(float metric);
36 
37  private:
38   void reset();
39   void reduce_lr(int epoch);
40   bool in_cooldown();
41   bool is_better(float a);
42   void init_is_better(
43       SchedulerMode mode,
44       double threshold,
45       ThresholdMode threshold_mode);
46 
47   Optimizer& optimizer;
48   SchedulerMode mode;
49   float mode_worse;
50   float factor;
51   int patience;
52   double threshold;
53   ThresholdMode threshold_mode;
54   int cooldown;
55   int cooldown_counter;
56   std::vector<float> min_lrs;
57   double eps;
58   float best;
59   bool verbose;
60   int last_epoch;
61   int num_bad_epochs;
62 };
63 } // namespace optim
64 } // namespace torch
65