# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""
import math
import types
import warnings
from bisect import bisect_right
from collections import Counter
from functools import partial, wraps
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    SupportsFloat,
    TypedDict,
    Union,
)
from weakref import ref

from torch import inf, Tensor

from .optimizer import Optimizer


__all__ = [
    "LambdaLR",
    "MultiplicativeLR",
    "StepLR",
    "MultiStepLR",
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "SequentialLR",
    "CosineAnnealingLR",
    "ChainedScheduler",
    "ReduceLROnPlateau",
    "CyclicLR",
    "CosineAnnealingWarmRestarts",
    "OneCycleLR",
    "PolynomialLR",
    "LRScheduler",
]

EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)


def _check_verbose_deprecated_warning(verbose):
    """Raise a warning when verbose is not the default value."""
    if verbose != "deprecated":
        warnings.warn(
            "The verbose parameter is deprecated. Please use get_last_lr() "
            "to access the learning rate.",
            UserWarning,
        )
        return verbose
    return False


def _format_param(name: str, optimizer: Optimizer, param):
    """Return correctly formatted lr/momentum for each param group."""

    def _copy(_param):
        return _param.clone() if isinstance(_param, Tensor) else _param

    if isinstance(param, (list, tuple)):
        if len(param) != len(optimizer.param_groups):
            raise ValueError(
                f"{name} must have the same length as optimizer.param_groups. "
                f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}."
            )
    else:
        param = [param] * len(optimizer.param_groups)

    return list(map(_copy, param))


class LRScheduler:
    r"""Adjusts the learning rate during optimization.
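
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal usage sketch, assuming `optimizer`, `train`, and
        >>> # `validate` are defined elsewhere. Any subclass (here StepLR)
        >>> # follows the same pattern: `optimizer.step()` runs inside the
        >>> # training loop, then `scheduler.step()` is called once per epoch.
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()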
    """

    _get_lr_called_within_step: bool = False

    def __init__(
        self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"
    ):  # noqa: D107
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                initial_lr = group["lr"]
                if isinstance(initial_lr, Tensor):
                    initial_lr = initial_lr.clone()
                group.setdefault("initial_lr", initial_lr)
        else:
            for i, group in enumerate(optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified "
                        f"in param_groups[{i}] when resuming an optimizer"
                    )
        self.base_lrs: List[float] = [
            group["initial_lr"] for group in optimizer.param_groups
        ]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def patch_track_step_called(opt: Optimizer):
            if hasattr(opt.step, "_wrapped_by_lr_sched"):
                # we've already patched
                return opt.step

            def wrap_step(step_fn):
                opt_ref = ref(self.optimizer)
                func = step_fn.__func__

                @wraps(func)
                def wrapper(*args, **kwargs):
                    opt = opt_ref()
                    opt._opt_called = True  # type: ignore[union-attr]
                    return func.__get__(opt, opt.__class__)(*args, **kwargs)

                wrapper._wrapped_by_lr_sched = True  # type: ignore[attr-defined]
                return wrapper

            opt.step = wrap_step(opt.step)  # type: ignore[method-assign]

        patch_track_step_called(self.optimizer)
        self.verbose = _check_verbose_deprecated_warning(verbose)
        self._initial_step()

    def _initial_step(self):
        """Initialize step counts and perform a step."""
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler."""
        return self._last_lr

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        raise NotImplementedError

    def print_lr(
        self,
        is_verbose: bool,
        group: Dict[str, Any],
        lr: float,
        epoch: Optional[int] = None,
    ):
        """Display the current learning rate.

        .. deprecated:: 2.4
            ``print_lr()`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.
        """
        warnings.warn(
            "`LRScheduler.print_lr()` is being deprecated. To fetch the learning rate, "
            "please use `get_last_lr()` instead. For more details, "
            "see https://github.com/pytorch/pytorch/issues/99270.",
            UserWarning,
        )
        if is_verbose:
            if epoch is None:
                print(f"Adjusting learning rate of group {group} to {lr:.4e}.")
            else:
                epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                print(
                    f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}."
                )

    def step(self, epoch: Optional[int] = None):
        """Perform a step."""
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
                warnings.warn(
                    "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                    "initialization. Please, make sure to call `optimizer.step()` before "
                    "`lr_scheduler.step()`. See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                    UserWarning,
                )

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif not getattr(self.optimizer, "_opt_called", False):
                warnings.warn(
                    "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                    "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                    "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                    "will result in PyTorch skipping the first value of the learning rate schedule. "
                    "See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                    UserWarning,
                )
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = cast(List[float], self._get_closed_form_lr())
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            if isinstance(param_group["lr"], Tensor):
                param_group["lr"].fill_(lr)
            else:
                param_group["lr"] = lr

        self._last_lr: List[float] = [
            group["lr"] for group in self.optimizer.param_groups
        ]


def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler):
    if not lr_scheduler._get_lr_called_within_step:
        warnings.warn(
            "To get the last learning rate computed by the scheduler, "
            "please use `get_last_lr()`.",
            UserWarning,
            stacklevel=2,
        )


# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
    pass


class _enable_get_lr_call:
    def __init__(self, o: LRScheduler):
        self.o = o

    def __enter__(self):
        self.o._get_lr_called_within_step = True
        return self

    def __exit__(self, type, value, traceback):
        self.o._get_lr_called_within_step = False


class LambdaLR(LRScheduler):
    """Sets the initial learning rate.

    The learning rate of each parameter group is set to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.optimizer = optimizer

        self.lr_lambdas: List[Callable[[int], float]]
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(
                    f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
                )
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """
        state_dict = {
            key: value
            for key, value in self.__dict__.items()
            if key not in ("optimizer", "lr_lambdas")
        }
        state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the scheduler's state.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop("lr_lambdas")
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict["lr_lambdas"] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Compute learning rate."""
        _warn_get_lr_called_within_step(self)

        return [
            base_lr * lmbda(self.last_epoch)
            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)
        ]


class MultiplicativeLR(LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given in the specified function.

    When last_epoch=-1, set initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.optimizer = optimizer

        self.lr_lambdas: List[Callable[[int], float]]
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(
                    f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
                )
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {
            key: value
            for key, value in self.__dict__.items()
            if key not in ("optimizer", "lr_lambdas")
        }
        state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop("lr_lambdas")
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict["lr_lambdas"] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch > 0:
            return [
                group["lr"] * lmbda(self.last_epoch)
                for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
            ]
        else:
            return [group["lr"] for group in self.optimizer.param_groups]


class StepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every step_size epochs.

    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        step_size: int,
        gamma=0.1,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group["lr"] for group in self.optimizer.param_groups]
        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [
            base_lr * self.gamma ** (self.last_epoch // self.step_size)
            for base_lr in self.base_lrs
        ]


class MultiStepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.

    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        milestones: Iterable[int],
        gamma=0.1,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch not in self.milestones:
            return [group["lr"] for group in self.optimizer.param_groups]
        return [
            group["lr"] * self.gamma ** self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        milestones = sorted(self.milestones.elements())
        return [
            base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]


class ConstantLR(LRScheduler):
    """Multiply the learning rate of each parameter group by a small constant factor.

    The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such multiplication of the small constant factor can
    happen simultaneously with other changes to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
        total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        factor=1.0 / 3,
        total_iters=5,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        if factor > 1.0 or factor < 0:
            raise ValueError(
                "Constant multiplicative factor expected to be between 0 and 1."
            )

        self.factor = factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch == 0:
            return [group["lr"] * self.factor for group in self.optimizer.param_groups]

        if self.last_epoch != self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        return [
            group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        return [
            base_lr
            * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
            for base_lr in self.base_lrs
        ]


class LinearLR(LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small multiplicative factor.

    The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Default: 1.0.
        total_iters (int): The number of iterations that multiplicative factor reaches to 1.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        start_factor=1.0 / 3,
        end_factor=1.0,
        total_iters=5,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError(
                "Starting multiplicative factor expected to be greater than 0 and less or equal to 1."
            )

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError(
                "Ending multiplicative factor expected to be between 0 and 1."
            )

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch == 0:
            return [
                group["lr"] * self.start_factor for group in self.optimizer.param_groups
            ]

        if self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        return [
            group["lr"]
            * (
                1.0
                + (self.end_factor - self.start_factor)
                / (
                    self.total_iters * self.start_factor
                    + (self.last_epoch - 1) * (self.end_factor - self.start_factor)
                )
            )
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        return [
            base_lr
            * (
                self.start_factor
                + (self.end_factor - self.start_factor)
                * min(self.total_iters, self.last_epoch)
                / self.total_iters
            )
            for base_lr in self.base_lrs
        ]


class ExponentialLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.

    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.
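
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming optimizer uses lr = 0.1 for all groups:
        >>> # lr = 0.1 * (0.9 ** epoch), i.e. 0.1, 0.09, 0.081, ...
        >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()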
    """

    def __init__(
        self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated"
    ):  # noqa: D107
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch == 0:
            return [group["lr"] for group in self.optimizer.param_groups]
        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]


class SequentialLR(LRScheduler):
    """Contains a list of schedulers expected to be called sequentially during the optimization process.

    Specifically, the schedulers will be called according to the milestone points, which should provide exact
    intervals by which each scheduler should be called at a given epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): Does nothing.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        schedulers: List[LRScheduler],
        milestones: List[int],
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        if len(schedulers) < 1:
            raise ValueError(
                f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
            )

        for scheduler_idx, scheduler in enumerate(schedulers):
            if not hasattr(scheduler, "optimizer"):
                raise TypeError(
                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
                )
            if isinstance(scheduler, ReduceLROnPlateau):
                raise ValueError(
                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
                    "requires additional kwargs to be specified when calling `step`, "
                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
                )
            if optimizer != scheduler.optimizer:
                raise ValueError(
                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
                    f"which is different from {optimizer.__class__.__name__}."
                )

        if len(milestones) != len(schedulers) - 1:
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
                f"number of milestones to be equal to {len(milestones)}"
            )
        _check_verbose_deprecated_warning(verbose)
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer

        # Reset learning rates back to initial values
        for group in self.optimizer.param_groups:
            group["lr"] = group["initial_lr"]

        # "Undo" the step performed by other schedulers
        for scheduler in self._schedulers:
            scheduler.last_epoch -= 1

        # Perform the initial step for only the first scheduler
        self._schedulers[0]._initial_step()

        self._last_lr = schedulers[0].get_last_lr()

    def step(self):
        """Perform a step."""
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            scheduler.step(0)
        else:
            scheduler.step()

        self._last_lr = scheduler.get_last_lr()

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {
            key: value
            for key, value in self.__dict__.items()
            if key not in ("optimizer", "_schedulers")
        }
        state_dict["_schedulers"] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict["_schedulers"][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop("_schedulers")
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict["_schedulers"] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class PolynomialLR(LRScheduler):
    """Decays the learning rate of each parameter group using a polynomial function in the given total_iters.

    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
        power (float): The power of the polynomial. Default: 1.0.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Assuming optimizer uses lr = 0.001 for all groups
        >>> # lr = 0.001     if epoch == 0
        >>> # lr = 0.00075   if epoch == 1
        >>> # lr = 0.00050   if epoch == 2
        >>> # lr = 0.00025   if epoch == 3
        >>> # lr = 0.0       if epoch >= 4
        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_iters=5,
        power=1.0,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.total_iters = total_iters
        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the learning rate."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        decay_factor = (
            (1.0 - self.last_epoch / self.total_iters)
            / (1.0 - (self.last_epoch - 1) / self.total_iters)
        ) ** self.power
        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [
            (
                base_lr
                * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters)
                ** self.power
            )
            for base_lr in self.base_lrs
        ]


class CosineAnnealingLR(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing schedule.

    The :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.
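
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming `optimizer`, `train`, and `validate`
        >>> # are defined elsewhere; the lr anneals from its initial value
        >>> # towards eta_min over T_max steps.
        >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()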

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self,
        optimizer: Optimizer,
        T_max: int,
        eta_min=0.0,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Retrieve the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        if self.last_epoch == 0:
            return [group["lr"] for group in self.optimizer.param_groups]
        elif self._step_count == 1 and self.last_epoch > 0:
            return [
                self.eta_min
                + (base_lr - self.eta_min)
                * (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
                / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        return [
            (1 + math.cos(math.pi * self.last_epoch / self.T_max))
            / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        return [
            self.eta_min
            + (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
            / 2
            for base_lr in self.base_lrs
        ]


class ChainedScheduler(LRScheduler):
    """Chains a list of learning rate schedulers.

    Takes in a sequence of chainable learning rate schedulers and calls their
    step() functions consecutively in just one call to step().

    Args:
        schedulers (sequence): sequence of chained schedulers.
        optimizer (Optimizer, optional): Wrapped optimizer. Default: None.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None
    ):  # noqa: D107
        if len(schedulers) < 1:
            raise ValueError(
                f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
            )

        optimizer = optimizer or schedulers[0].optimizer
        for scheduler_idx, scheduler in enumerate(schedulers):
            if not hasattr(scheduler, "optimizer"):
                raise TypeError(
                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
                )
            if isinstance(scheduler, ReduceLROnPlateau):
                raise ValueError(
                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
                    "requires additional kwargs to be specified when calling `step`, "
                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
                )
            if optimizer != scheduler.optimizer:
                raise ValueError(
                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
                    f"which is different from {optimizer.__class__.__name__}."
                )
        self._schedulers = schedulers
        self.optimizer = optimizer
        self._last_lr = [
            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
        ]

    def step(self):
        """Perform a step."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = [
            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
        ]

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {
            key: value
            for key, value in self.__dict__.items()
            if key not in ("optimizer", "_schedulers")
        }
        state_dict["_schedulers"] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict["_schedulers"][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop("_schedulers")
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict["_schedulers"] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class ReduceLROnPlateau(LRScheduler):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): The number of allowed epochs with no improvement after
            which the learning rate will be reduced.
            For example, consider the case of having no patience (`patience = 0`).
            In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
            In the second epoch, if the performance is worse than the baseline,
            we have what is considered an intolerable epoch.
            Since the count of intolerable epochs (1) is greater than the patience level (0),
            the learning rate is reduced at the end of this epoch.
            From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
            if the performance is worse than the baseline. If the performance improves or remains the same,
            the learning rate is not adjusted.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        mode: Literal["min", "max"] = "min",
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode: Literal["rel", "abs"] = "rel",
        cooldown=0,
        min_lr: Union[List[float], float] = 0,
        eps=1e-8,
        verbose="deprecated",
    ):  # noqa: D107
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError(
                    f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
                )
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience

        self.verbose = _check_verbose_deprecated_warning(verbose)
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best: float
        self.num_bad_epochs: int
        self.mode_worse: float  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
        self._init_is_better(
            mode=mode, threshold=threshold, threshold_mode=threshold_mode
        )
        self._reset()

    def _reset(self):
        """Reset num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics: SupportsFloat, epoch=None):  # type: ignore[override]
        """Perform a step."""
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group["lr"] = new_lr

    @property
    def in_cooldown(self):  # noqa: D102
        return self.cooldown_counter > 0

    def is_better(self, a, best):  # noqa: D102
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {"min", "max"}:
            raise ValueError("mode " + mode + " is unknown!")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError("threshold mode " + threshold_mode + " is unknown!")

        if mode == "min":
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):  # noqa: D102
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        """Load the scheduler's state."""
        self.__dict__.update(state_dict)
        self._init_is_better(
            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
        )


class CyclicLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).

    The policy cycles the learning rate between two boundaries with a constant frequency,
    as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()
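        >>>
        >>> # A hedged sketch of a custom scaling policy: `half_each_cycle` is a
        >>> # hypothetical helper that halves the cycle amplitude every full
        >>> # cycle, mirroring the built-in 'triangular2' mode.
        >>> half_each_cycle = lambda cycle: 1.0 / (2.0 ** (cycle - 1))
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(
        >>>     optimizer, base_lr=0.01, max_lr=0.1,
        >>>     scale_fn=half_each_cycle, scale_mode='cycle')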

    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(
        self,
        optimizer: Optimizer,
        base_lr: Union[float, List[float]],
        max_lr: Union[float, List[float]],
        step_size_up=2000,
        step_size_down: Optional[int] = None,
        mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
        gamma=1.0,
        scale_fn: Optional[Callable[[float], float]] = None,
        scale_mode: Literal["cycle", "iterations"] = "cycle",
        cycle_momentum=True,
        base_momentum=0.8,
        max_momentum=0.9,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer

        base_lrs = _format_param("base_lr", optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                if isinstance(group["lr"], Tensor):
                    lr_val = lr.item() if isinstance(lr, Tensor) else lr
                    group["lr"].fill_(lr_val)
                else:
                    group["lr"] = lr

        self.max_lrs = _format_param("max_lr", optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = (
            float(step_size_down) if step_size_down is not None else step_size_up
        )
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
            raise ValueError("mode is invalid and scale_fn is None")

        self.mode = mode
        self.gamma = gamma

        self._scale_fn_ref: Callable[[float], float]
        self._scale_fn_custom = scale_fn
        self.scale_mode = scale_mode
        self._init_scale_fn()

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if (
                "momentum" not in optimizer.defaults
                and "betas" not in optimizer.defaults
            ):
                raise ValueError(
                    "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
                )

            self.use_beta1 = "betas" in self.optimizer.defaults
            self.base_momentums = _format_param(
                "base_momentum", optimizer, base_momentum
            )
            self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(
                    self.max_momentums, self.base_momentums, optimizer.param_groups
                ):
                    if self.use_beta1:
                        group["betas"] = (m_momentum, *group["betas"][1:])
                    else:
                        group["momentum"] = m_momentum
                    group["max_momentum"] = m_momentum
                    group["base_momentum"] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)
        self.base_lrs = base_lrs

    def _init_scale_fn(self):
        if self._scale_fn_custom is not None:
            return
        if self.mode == "triangular":
            self._scale_fn_ref = self._triangular_scale_fn
            self.scale_mode = "cycle"
        elif self.mode == "triangular2":
            self._scale_fn_ref = self._triangular2_scale_fn
            self.scale_mode = "cycle"
        elif self.mode == "exp_range":
            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
            self.scale_mode = "iterations"

    def scale_fn(self, x) -> float:
        """Get the scaling policy."""
        if self._scale_fn_custom is not None:
            return self._scale_fn_custom(x)
        else:
            return self._scale_fn_ref(x)  # static method

    @staticmethod
    def _triangular_scale_fn(x: float) -> float:
        return 1.0

    @staticmethod
    def _triangular2_scale_fn(x: float) -> float:
        return 1 / (2.0 ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(gamma: float, x: float) -> float:
        return gamma**x

    def get_lr(self):
        """Calculate the learning rate at batch index.

        This function treats `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        _warn_get_lr_called_within_step(self)

        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1.0 + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == "cycle":
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(
                self.base_momentums, self.max_momentums
            ):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == "cycle":
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(
                        self.last_epoch
                    )
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                if self.use_beta1:
                    param_group["betas"] = (momentum, *param_group["betas"][1:])
                else:
                    param_group["momentum"] = momentum

        return lrs

    def state_dict(self):  # noqa: D102
        state = super().state_dict()
        # We are dropping the `_scale_fn_ref` attribute because it is a
        # `weakref.WeakMethod` and can't be pickled.
        state.pop("_scale_fn_ref", None)
        fn = state.pop("_scale_fn_custom")
        state["_scale_fn_custom"] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state["_scale_fn_custom"] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler's state."""
        fn = state_dict.pop("_scale_fn_custom")
        super().load_state_dict(state_dict)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()


class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing schedule.

    The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations until the first restart.
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.
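
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming `optimizer`, `train`, and `validate`
        >>> # are defined elsewhere; the lr restarts to its initial value every
        >>> # T_0 = 10 epochs (T_mult=1 keeps the restart period fixed).
        >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()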
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self,
        optimizer: Optimizer,
        T_0: int,
        T_mult=1,
        eta_min=0.0,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(
                f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}"
            )
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Compute the current learning rate for each parameter group."""
        _warn_get_lr_called_within_step(self)

        return [
            self.eta_min
            + (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
            / 2
            for base_lr in self.base_lrs
        ]

    def step(self, epoch=None):
        """Step could be called after every batch update.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler.step(20)
        """
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(
                        math.log(
                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
                        )
                    )
                    self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
                        self.T_mult - 1
                    )
                    self.T_i = self.T_0 * self.T_mult**n
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group["lr"] = lr

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]


class _SchedulePhase(TypedDict):
    end_step: float
    start_lr: str
    end_lr: str
    start_momentum: str
    end_momentum: str


class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy.

    The 1cycle policy anneals the learning rate from an initial learning rate to some maximum
    learning rate and then from that maximum learning rate to some minimum learning rate much
    lower than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for.
            This is used along with steps_per_epoch in order to infer the total
            number of steps in the cycle if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'.
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         optimizer.step()
        >>>         scheduler.step()

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """

    def __init__(
        self,
        optimizer: Optimizer,
        max_lr: Union[float, List[float]],
        total_steps: Optional[int] = None,
        epochs: Optional[int] = None,
        steps_per_epoch: Optional[int] = None,
        pct_start=0.3,
        anneal_strategy: Literal["cos", "linear"] = "cos",
        cycle_momentum=True,
        base_momentum: Union[float, List[float]] = 0.85,
        max_momentum: Union[float, List[float]] = 0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        three_phase=False,
        last_epoch=-1,
        verbose="deprecated",
    ):  # noqa: D107
        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer

        # Validate total_steps
        if total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError(
                    f"Expected positive integer total_steps, but got {total_steps}"
                )
            self.total_steps = total_steps
        elif epochs is not None and steps_per_epoch is not None:
            if not isinstance(epochs, int) or epochs <= 0:
                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
                raise ValueError(
                    f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
                )
            self.total_steps = epochs * steps_per_epoch
        else:
            raise ValueError(
                "You must define either total_steps OR (epochs AND steps_per_epoch)"
            )

        self._schedule_phases: List[_SchedulePhase]
        if three_phase:
            self._schedule_phases = [
                {
                    "end_step": float(pct_start * self.total_steps) - 1,
                    "start_lr": "initial_lr",
                    "end_lr": "max_lr",
                    "start_momentum": "max_momentum",
                    "end_momentum": "base_momentum",
                },
                {
                    "end_step": float(2 * pct_start * self.total_steps) - 2,
                    "start_lr": "max_lr",
                    "end_lr": "initial_lr",
                    "start_momentum": "base_momentum",
                    "end_momentum": "max_momentum",
                },
                {
                    "end_step": self.total_steps - 1,
                    "start_lr": "initial_lr",
                    "end_lr": "min_lr",
                    "start_momentum": "max_momentum",
                    "end_momentum": "max_momentum",
                },
            ]
        else:
            self._schedule_phases = [
                {
                    "end_step": float(pct_start * self.total_steps) - 1,
                    "start_lr": "initial_lr",
                    "end_lr": "max_lr",
                    "start_momentum": "max_momentum",
                    "end_momentum": "base_momentum",
                },
                {
                    "end_step": self.total_steps - 1,
                    "start_lr": "max_lr",
                    "end_lr": "min_lr",
                    "start_momentum": "base_momentum",
                    "end_momentum": "max_momentum",
                },
            ]

        # Validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError(
                f"Expected float between 0 and 1 pct_start, but got {pct_start}"
            )

        # Validate anneal_strategy
        if anneal_strategy not in ["cos", "linear"]:
            raise ValueError(
                f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
            )
        else:
            self._anneal_func_type = anneal_strategy

        # Initialize learning rate variables
        max_lrs = _format_param("max_lr", self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group["initial_lr"] = max_lrs[idx] / div_factor
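                # Illustrative sketch with hypothetical values: for max_lr=0.01,
                # div_factor=25 and final_div_factor=1e4, this loop would yield
                #   initial_lr = 0.01 / 25    = 4.0e-4  (warmup starts here)
                #   min_lr     = 4.0e-4 / 1e4 = 4.0e-8  (final annealing target)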
                group["max_lr"] = max_lrs[idx]
                group["min_lr"] = group["initial_lr"] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if (
                "momentum" not in self.optimizer.defaults
                and "betas" not in self.optimizer.defaults
            ):
                raise ValueError(
                    "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
                )
            self.use_beta1 = "betas" in self.optimizer.defaults
            max_momentums = _format_param("max_momentum", optimizer, max_momentum)
            base_momentums = _format_param("base_momentum", optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(
                    max_momentums, base_momentums, optimizer.param_groups
                ):
                    if self.use_beta1:
                        group["betas"] = (m_momentum, *group["betas"][1:])
                    else:
                        group["momentum"] = m_momentum
                    group["max_momentum"] = m_momentum
                    group["base_momentum"] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)

    def _anneal_func(self, *args, **kwargs):
        if hasattr(self, "_anneal_func_type"):
            if self._anneal_func_type == "cos":
                return self._annealing_cos(*args, **kwargs)
            elif self._anneal_func_type == "linear":
                return self._annealing_linear(*args, **kwargs)
            else:
                raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}")
        else:
            # For backward compatibility with state that set `anneal_func` directly.
            return self.anneal_func(*args, **kwargs)  # type: ignore[attr-defined]

    @staticmethod
    def _annealing_cos(start, end, pct):
        """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start, end, pct):
        """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
        return (end - start) * pct + start

    def get_lr(self):
        """Compute the learning rate of each parameter group."""
        _warn_get_lr_called_within_step(self)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"  # noqa: UP032
            )

        for group in self.optimizer.param_groups:
            start_step = 0.0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase["end_step"]
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self._anneal_func(
                        group[phase["start_lr"]], group[phase["end_lr"]], pct
                    )
                    if self.cycle_momentum:
                        computed_momentum = self._anneal_func(
                            group[phase["start_momentum"]],
                            group[phase["end_momentum"]],
                            pct,
                        )
                    break
                start_step = phase["end_step"]

            lrs.append(computed_lr)  # type: ignore[possibly-undefined]
            if self.cycle_momentum:
                if self.use_beta1:
                    group["betas"] = (computed_momentum, *group["betas"][1:])  # type: ignore[possibly-undefined]
                else:
                    group["momentum"] = computed_momentum  # type: ignore[possibly-undefined]

        return lrs
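
    # Illustrative sketch of the phase walk in ``get_lr`` above (hypothetical
    # values, default two-phase schedule): with total_steps=1000 and
    # pct_start=0.3, phase 1 ends at end_step = 0.3 * 1000 - 1 = 299 and phase 2
    # at 999. Within a phase, pct = (step - start_step) / (end_step - start_step)
    # runs from 0 to 1, and ``_annealing_cos`` returns
    #   end + (start - end) / 2 * (cos(pi * pct) + 1)
    # which equals ``start`` at pct=0 (cos(0)+1 = 2) and ``end`` at pct=1
    # (cos(pi)+1 = 0), so the lr rises initial_lr -> max_lr over phase 1 and
    # decays max_lr -> min_lr over phase 2, while momentum moves in the
    # opposite direction when ``cycle_momentum`` is enabled.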