1 #pragma once 2 3 #include <torch/optim/optimizer.h> 4 #include <torch/optim/schedulers/lr_scheduler.h> 5 6 #include <torch/csrc/Export.h> 7 8 #include <string> 9 10 #include <cmath> 11 12 #include <iostream> 13 14 namespace torch { 15 namespace optim { 16 17 class TORCH_API ReduceLROnPlateauScheduler { 18 public: 19 enum SchedulerMode { min, max }; 20 enum ThresholdMode { rel, abs }; 21 ReduceLROnPlateauScheduler( 22 Optimizer& optimizer, 23 SchedulerMode mode = min, 24 float factor = 0.1, 25 int patience = 10, 26 double threshold = 1e-4, 27 ThresholdMode threshold_mode = rel, 28 int cooldown = 0, 29 const std::vector<float>& min_lr = std::vector<float>(), 30 double eps = 1e-8, 31 bool verbose = false); 32 33 virtual ~ReduceLROnPlateauScheduler() = default; 34 35 void step(float metric); 36 37 private: 38 void reset(); 39 void reduce_lr(int epoch); 40 bool in_cooldown(); 41 bool is_better(float a); 42 void init_is_better( 43 SchedulerMode mode, 44 double threshold, 45 ThresholdMode threshold_mode); 46 47 Optimizer& optimizer; 48 SchedulerMode mode; 49 float mode_worse; 50 float factor; 51 int patience; 52 double threshold; 53 ThresholdMode threshold_mode; 54 int cooldown; 55 int cooldown_counter; 56 std::vector<float> min_lrs; 57 double eps; 58 float best; 59 bool verbose; 60 int last_epoch; 61 int num_bad_epochs; 62 }; 63 } // namespace optim 64 } // namespace torch 65