1# mypy: allow-untyped-defs
2r"""Learning Rate Scheduler."""
3import math
4import types
5import warnings
6from bisect import bisect_right
7from collections import Counter
8from functools import partial, wraps
9from typing import (
10    Any,
11    Callable,
12    cast,
13    Dict,
14    Iterable,
15    List,
16    Literal,
17    Optional,
18    Sequence,
19    SupportsFloat,
20    TypedDict,
21    Union,
22)
23from weakref import ref
24
25from torch import inf, Tensor
26
27from .optimizer import Optimizer
28
29
30__all__ = [
31    "LambdaLR",
32    "MultiplicativeLR",
33    "StepLR",
34    "MultiStepLR",
35    "ConstantLR",
36    "LinearLR",
37    "ExponentialLR",
38    "SequentialLR",
39    "CosineAnnealingLR",
40    "ChainedScheduler",
41    "ReduceLROnPlateau",
42    "CyclicLR",
43    "CosineAnnealingWarmRestarts",
44    "OneCycleLR",
45    "PolynomialLR",
46    "LRScheduler",
47]
48
49EPOCH_DEPRECATION_WARNING = (
50    "The epoch parameter in `scheduler.step()` was not necessary and is being "
51    "deprecated where possible. Please use `scheduler.step()` to step the "
52    "scheduler. During the deprecation, if epoch is different from None, the "
53    "closed form is used instead of the new chainable form, where available. "
54    "Please open an issue if you are unable to replicate your use case: "
55    "https://github.com/pytorch/pytorch/issues/new/choose."
56)
57
58
59def _check_verbose_deprecated_warning(verbose):
    """Warn when verbose is not the default value."""
61    if verbose != "deprecated":
62        warnings.warn(
63            "The verbose parameter is deprecated. Please use get_last_lr() "
64            "to access the learning rate.",
65            UserWarning,
66        )
67        return verbose
68    return False
69
70
71def _format_param(name: str, optimizer: Optimizer, param):
72    """Return correctly formatted lr/momentum for each param group."""
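    # For example (illustrative values), _format_param("lr", optimizer, 0.01)
    # broadcasts the scalar to [0.01] * len(optimizer.param_groups), while a
    # list or tuple must already contain one value per param group.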
73
74    def _copy(_param):
75        return _param.clone() if isinstance(_param, Tensor) else _param
76
77    if isinstance(param, (list, tuple)):
78        if len(param) != len(optimizer.param_groups):
79            raise ValueError(
80                f"{name} must have the same length as optimizer.param_groups. "
81                f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}."
82            )
83    else:
84        param = [param] * len(optimizer.param_groups)
85
86    return list(map(_copy, param))
87
88
89class LRScheduler:
90    r"""Adjusts the learning rate during optimization."""
91
92    _get_lr_called_within_step: bool = False
93
94    def __init__(
95        self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"
96    ):  # noqa: D107
97        # Attach optimizer
98        if not isinstance(optimizer, Optimizer):
99            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
100        self.optimizer = optimizer
101
102        # Initialize epoch and base learning rates
103        if last_epoch == -1:
104            for group in optimizer.param_groups:
105                initial_lr = group["lr"]
106                if isinstance(initial_lr, Tensor):
107                    initial_lr = initial_lr.clone()
108                group.setdefault("initial_lr", initial_lr)
109        else:
110            for i, group in enumerate(optimizer.param_groups):
111                if "initial_lr" not in group:
112                    raise KeyError(
113                        "param 'initial_lr' is not specified "
114                        f"in param_groups[{i}] when resuming an optimizer"
115                    )
116        self.base_lrs: List[float] = [
117            group["initial_lr"] for group in optimizer.param_groups
118        ]
119        self.last_epoch = last_epoch
120
121        # Following https://github.com/pytorch/pytorch/issues/20124
122        # We would like to ensure that `lr_scheduler.step()` is called after
123        # `optimizer.step()`
124        def patch_track_step_called(opt: Optimizer):
125            if hasattr(opt.step, "_wrapped_by_lr_sched"):
126                # we've already patched
127                return opt.step
128
129            def wrap_step(step_fn):
130                opt_ref = ref(self.optimizer)
131                func = step_fn.__func__
132
133                @wraps(func)
134                def wrapper(*args, **kwargs):
135                    opt = opt_ref()
136                    opt._opt_called = True  # type: ignore[union-attr]
137                    return func.__get__(opt, opt.__class__)(*args, **kwargs)
138
139                wrapper._wrapped_by_lr_sched = True  # type: ignore[attr-defined]
140                return wrapper
141
142            opt.step = wrap_step(opt.step)  # type: ignore[method-assign]
143
144        patch_track_step_called(self.optimizer)
145        self.verbose = _check_verbose_deprecated_warning(verbose)
146        self._initial_step()
147
148    def _initial_step(self):
149        """Initialize step counts and perform a step."""
150        self._step_count = 0
151        self.step()
152
153    def state_dict(self):
154        """Return the state of the scheduler as a :class:`dict`.
155
156        It contains an entry for every variable in self.__dict__ which
157        is not the optimizer.
158        """
159        return {
160            key: value for key, value in self.__dict__.items() if key != "optimizer"
161        }
162
163    def load_state_dict(self, state_dict: Dict[str, Any]):
164        """Load the scheduler's state.
165
166        Args:
167            state_dict (dict): scheduler state. Should be an object returned
168                from a call to :meth:`state_dict`.
169        """
170        self.__dict__.update(state_dict)
171
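    # A small sketch of how the two methods above are meant to round-trip
    # through a checkpoint (assumes `opt` and `sched` already exist; the
    # checkpoint path is illustrative):
    #
    #     import torch
    #     ckpt = {"optimizer": opt.state_dict(), "scheduler": sched.state_dict()}
    #     torch.save(ckpt, "checkpoint.pt")
    #     ...
    #     ckpt = torch.load("checkpoint.pt")
    #     opt.load_state_dict(ckpt["optimizer"])
    #     sched.load_state_dict(ckpt["scheduler"])
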
172    def get_last_lr(self) -> List[float]:
        """Return the last learning rate computed by the current scheduler."""
174        return self._last_lr
175
176    def get_lr(self) -> List[float]:
177        """Compute learning rate using chainable form of the scheduler."""
178        raise NotImplementedError
179
180    def print_lr(
181        self,
182        is_verbose: bool,
183        group: Dict[str, Any],
184        lr: float,
185        epoch: Optional[int] = None,
186    ):
187        """Display the current learning rate.
188
189        .. deprecated:: 2.4
190            ``print_lr()`` is deprecated. Please use ``get_last_lr()`` to access the
191            learning rate.
192        """
193        warnings.warn(
194            "`LRScheduler.print_lr()` is being deprecated. To fetch the learning rate, "
195            "please use `get_last_lr()` instead. For more details, "
196            "see https://github.com/pytorch/pytorch/issues/99270.",
197            UserWarning,
198        )
199        if is_verbose:
200            if epoch is None:
201                print(f"Adjusting learning rate of group {group} to {lr:.4e}.")
202            else:
203                epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
204                print(
205                    f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}."
206                )
207
208    def step(self, epoch: Optional[int] = None):
209        """Perform a step."""
210        # Raise a warning if old pattern is detected
211        # https://github.com/pytorch/pytorch/issues/20124
212        if self._step_count == 1:
213            if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
214                warnings.warn(
215                    "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                    "initialization. Please make sure to call `optimizer.step()` before "
217                    "`lr_scheduler.step()`. See more details at "
218                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
219                    UserWarning,
220                )
221
222            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
223            elif not getattr(self.optimizer, "_opt_called", False):
224                warnings.warn(
225                    "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
226                    "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
227                    "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
228                    "will result in PyTorch skipping the first value of the learning rate schedule. "
229                    "See more details at "
230                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
231                    UserWarning,
232                )
233        self._step_count += 1
234
235        with _enable_get_lr_call(self):
236            if epoch is None:
237                self.last_epoch += 1
238                values = self.get_lr()
239            else:
240                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
241                self.last_epoch = epoch
242                if hasattr(self, "_get_closed_form_lr"):
243                    values = cast(List[float], self._get_closed_form_lr())
244                else:
245                    values = self.get_lr()
246
247        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
248            param_group, lr = data
249            if isinstance(param_group["lr"], Tensor):
250                param_group["lr"].fill_(lr)
251            else:
252                param_group["lr"] = lr
253
254        self._last_lr: List[float] = [
255            group["lr"] for group in self.optimizer.param_groups
256        ]
257
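# A minimal sketch of how LRScheduler is meant to be used and extended (the
# subclass and loop below are illustrative, not part of this module):
#
#     class HalveEveryEpoch(LRScheduler):
#         def get_lr(self):
#             if self.last_epoch == 0:
#                 return [group["lr"] for group in self.optimizer.param_groups]
#             return [group["lr"] * 0.5 for group in self.optimizer.param_groups]
#
#     scheduler = HalveEveryEpoch(optimizer)
#     for _ in range(num_epochs):
#         optimizer.step()   # step the optimizer first ...
#         scheduler.step()   # ... then the scheduler, as the warnings in step() advise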
258
259def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler):
260    if not lr_scheduler._get_lr_called_within_step:
261        warnings.warn(
262            "To get the last learning rate computed by the scheduler, "
263            "please use `get_last_lr()`.",
264            UserWarning,
265            stacklevel=2,
266        )
267
268
269# Including _LRScheduler for backwards compatibility
270# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
271class _LRScheduler(LRScheduler):
272    pass
273
274
275class _enable_get_lr_call:
276    def __init__(self, o: LRScheduler):
277        self.o = o
278
279    def __enter__(self):
280        self.o._get_lr_called_within_step = True
281        return self
282
283    def __exit__(self, type, value, traceback):
284        self.o._get_lr_called_within_step = False
285
286
287class LambdaLR(LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr times a given function.

    When last_epoch=-1, sets initial lr as lr.
292
293    Args:
294        optimizer (Optimizer): Wrapped optimizer.
295        lr_lambda (function or list): A function which computes a multiplicative
296            factor given an integer parameter epoch, or a list of such
297            functions, one for each group in optimizer.param_groups.
298        last_epoch (int): The index of last epoch. Default: -1.
299        verbose (bool | str): If ``True``, prints a message to stdout for
300            each update. Default: ``False``.
301
302            .. deprecated:: 2.2
303                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
304                learning rate.
305
306    Example:
307        >>> # xdoctest: +SKIP
308        >>> # Assuming optimizer has two groups.
309        >>> lambda1 = lambda epoch: epoch // 30
310        >>> lambda2 = lambda epoch: 0.95 ** epoch
311        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
312        >>> for epoch in range(100):
313        >>>     train(...)
314        >>>     validate(...)
315        >>>     scheduler.step()
316    """
317
318    def __init__(
319        self,
320        optimizer: Optimizer,
321        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
322        last_epoch=-1,
323        verbose="deprecated",
324    ):  # noqa: D107
325        self.optimizer = optimizer
326
327        self.lr_lambdas: List[Callable[[int], float]]
328        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
329            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
330        else:
331            if len(lr_lambda) != len(optimizer.param_groups):
332                raise ValueError(
333                    f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
334                )
335            self.lr_lambdas = list(lr_lambda)
336        super().__init__(optimizer, last_epoch, verbose)
337
338    def state_dict(self):
339        """Return the state of the scheduler as a :class:`dict`.
340
341        It contains an entry for every variable in self.__dict__ which
342        is not the optimizer.
343        The learning rate lambda functions will only be saved if they are callable objects
344        and not if they are functions or lambdas.
345
346        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
347        """
348        state_dict = {
349            key: value
350            for key, value in self.__dict__.items()
351            if key not in ("optimizer", "lr_lambdas")
352        }
353        state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
354
355        for idx, fn in enumerate(self.lr_lambdas):
356            if not isinstance(fn, types.FunctionType):
357                state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
358
359        return state_dict
360
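    # Sketch of the rule documented in state_dict() above: a plain function or
    # lambda contributes ``None`` to state_dict()["lr_lambdas"], while a callable
    # object has its __dict__ captured, e.g. (illustrative class):
    #
    #     class WarmupLambda:
    #         def __init__(self, warmup):
    #             self.warmup = warmup
    #
    #         def __call__(self, epoch):
    #             return min(1.0, (epoch + 1) / self.warmup)
    #
    #     scheduler = LambdaLR(optimizer, lr_lambda=WarmupLambda(10))
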
361    def load_state_dict(self, state_dict):
362        """Load the scheduler's state.
363
364        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
365
366        Args:
367            state_dict (dict): scheduler state. Should be an object returned
368                from a call to :meth:`state_dict`.
369        """
370        lr_lambdas = state_dict.pop("lr_lambdas")
371        self.__dict__.update(state_dict)
372        # Restore state_dict keys in order to prevent side effects
373        # https://github.com/pytorch/pytorch/issues/32756
374        state_dict["lr_lambdas"] = lr_lambdas
375
376        for idx, fn in enumerate(lr_lambdas):
377            if fn is not None:
378                self.lr_lambdas[idx].__dict__.update(fn)
379
380    def get_lr(self):
381        """Compute learning rate."""
382        _warn_get_lr_called_within_step(self)
383
384        return [
385            base_lr * lmbda(self.last_epoch)
386            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)
387        ]
388
389
390class MultiplicativeLR(LRScheduler):
391    """Multiply the learning rate of each parameter group by the factor given in the specified function.
392
    When last_epoch=-1, sets initial lr as lr.
394
395    Args:
396        optimizer (Optimizer): Wrapped optimizer.
397        lr_lambda (function or list): A function which computes a multiplicative
398            factor given an integer parameter epoch, or a list of such
399            functions, one for each group in optimizer.param_groups.
400        last_epoch (int): The index of last epoch. Default: -1.
401        verbose (bool | str): If ``True``, prints a message to stdout for
402            each update. Default: ``False``.
403
404            .. deprecated:: 2.2
405                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
406                learning rate.
407
408    Example:
409        >>> # xdoctest: +SKIP
410        >>> lmbda = lambda epoch: 0.95
411        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
412        >>> for epoch in range(100):
413        >>>     train(...)
414        >>>     validate(...)
415        >>>     scheduler.step()
416    """
417
418    def __init__(
419        self,
420        optimizer: Optimizer,
421        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
422        last_epoch=-1,
423        verbose="deprecated",
424    ):  # noqa: D107
425        self.optimizer = optimizer
426
427        self.lr_lambdas: List[Callable[[int], float]]
428        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
429            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
430        else:
431            if len(lr_lambda) != len(optimizer.param_groups):
432                raise ValueError(
433                    f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
434                )
435            self.lr_lambdas = list(lr_lambda)
436        super().__init__(optimizer, last_epoch, verbose)
437
438    def state_dict(self):
439        """Return the state of the scheduler as a :class:`dict`.
440
441        It contains an entry for every variable in self.__dict__ which
442        is not the optimizer.
443        The learning rate lambda functions will only be saved if they are callable objects
444        and not if they are functions or lambdas.
445        """
446        state_dict = {
447            key: value
448            for key, value in self.__dict__.items()
449            if key not in ("optimizer", "lr_lambdas")
450        }
451        state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
452
453        for idx, fn in enumerate(self.lr_lambdas):
454            if not isinstance(fn, types.FunctionType):
455                state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
456
457        return state_dict
458
459    def load_state_dict(self, state_dict):
460        """Load the scheduler's state.
461
462        Args:
463            state_dict (dict): scheduler state. Should be an object returned
464                from a call to :meth:`state_dict`.
465        """
466        lr_lambdas = state_dict.pop("lr_lambdas")
467        self.__dict__.update(state_dict)
468        # Restore state_dict keys in order to prevent side effects
469        # https://github.com/pytorch/pytorch/issues/32756
470        state_dict["lr_lambdas"] = lr_lambdas
471
472        for idx, fn in enumerate(lr_lambdas):
473            if fn is not None:
474                self.lr_lambdas[idx].__dict__.update(fn)
475
476    def get_lr(self):
477        """Compute the learning rate of each parameter group."""
478        _warn_get_lr_called_within_step(self)
479
480        if self.last_epoch > 0:
481            return [
482                group["lr"] * lmbda(self.last_epoch)
483                for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
484            ]
485        else:
486            return [group["lr"] for group in self.optimizer.param_groups]
487
488
489class StepLR(LRScheduler):
490    """Decays the learning rate of each parameter group by gamma every step_size epochs.
491
492    Notice that such decay can happen simultaneously with other changes to the learning rate
493    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
494
495    Args:
496        optimizer (Optimizer): Wrapped optimizer.
497        step_size (int): Period of learning rate decay.
498        gamma (float): Multiplicative factor of learning rate decay.
499            Default: 0.1.
500        last_epoch (int): The index of last epoch. Default: -1.
501        verbose (bool | str): If ``True``, prints a message to stdout for
502            each update. Default: ``False``.
503
504            .. deprecated:: 2.2
505                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
506                learning rate.
507
508    Example:
509        >>> # xdoctest: +SKIP
510        >>> # Assuming optimizer uses lr = 0.05 for all groups
511        >>> # lr = 0.05     if epoch < 30
512        >>> # lr = 0.005    if 30 <= epoch < 60
513        >>> # lr = 0.0005   if 60 <= epoch < 90
514        >>> # ...
515        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
516        >>> for epoch in range(100):
517        >>>     train(...)
518        >>>     validate(...)
519        >>>     scheduler.step()
520    """
521
522    def __init__(
523        self,
524        optimizer: Optimizer,
525        step_size: int,
526        gamma=0.1,
527        last_epoch=-1,
528        verbose="deprecated",
529    ):  # noqa: D107
530        self.step_size = step_size
531        self.gamma = gamma
532        super().__init__(optimizer, last_epoch, verbose)
533
534    def get_lr(self):
535        """Compute the learning rate of each parameter group."""
536        _warn_get_lr_called_within_step(self)
537
538        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
539            return [group["lr"] for group in self.optimizer.param_groups]
540        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
541
542    def _get_closed_form_lr(self):
543        return [
544            base_lr * self.gamma ** (self.last_epoch // self.step_size)
545            for base_lr in self.base_lrs
546        ]
547
548
549class MultiStepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones.
551
552    Notice that such decay can happen simultaneously with other changes to the learning rate
553    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
554
555    Args:
556        optimizer (Optimizer): Wrapped optimizer.
557        milestones (list): List of epoch indices. Must be increasing.
558        gamma (float): Multiplicative factor of learning rate decay.
559            Default: 0.1.
560        last_epoch (int): The index of last epoch. Default: -1.
561        verbose (bool | str): If ``True``, prints a message to stdout for
562            each update. Default: ``False``.
563
564            .. deprecated:: 2.2
565                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
566                learning rate.
567
568    Example:
569        >>> # xdoctest: +SKIP
570        >>> # Assuming optimizer uses lr = 0.05 for all groups
571        >>> # lr = 0.05     if epoch < 30
572        >>> # lr = 0.005    if 30 <= epoch < 80
573        >>> # lr = 0.0005   if epoch >= 80
574        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
575        >>> for epoch in range(100):
576        >>>     train(...)
577        >>>     validate(...)
578        >>>     scheduler.step()
579    """
580
581    def __init__(
582        self,
583        optimizer: Optimizer,
584        milestones: Iterable[int],
585        gamma=0.1,
586        last_epoch=-1,
587        verbose="deprecated",
588    ):  # noqa: D107
589        self.milestones = Counter(milestones)
590        self.gamma = gamma
591        super().__init__(optimizer, last_epoch, verbose)
592
593    def get_lr(self):
594        """Compute the learning rate of each parameter group."""
595        _warn_get_lr_called_within_step(self)
596
597        if self.last_epoch not in self.milestones:
598            return [group["lr"] for group in self.optimizer.param_groups]
599        return [
600            group["lr"] * self.gamma ** self.milestones[self.last_epoch]
601            for group in self.optimizer.param_groups
602        ]
603
604    def _get_closed_form_lr(self):
605        milestones = sorted(self.milestones.elements())
606        return [
607            base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
608            for base_lr in self.base_lrs
609        ]
610
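    # e.g. with milestones=[30, 80] and gamma=0.1, last_epoch=45 gives
    # bisect_right([30, 80], 45) == 1, so the closed form is base_lr * 0.1 ** 1.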
611
612class ConstantLR(LRScheduler):
613    """Multiply the learning rate of each parameter group by a small constant factor.
614
    The multiplication is done until the number of epochs reaches a pre-defined milestone: total_iters.
616    Notice that such multiplication of the small constant factor can
617    happen simultaneously with other changes to the learning rate from outside this scheduler.
618    When last_epoch=-1, sets initial lr as lr.
619
620    Args:
621        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The factor by which the learning rate is multiplied until the milestone. Default: 1./3.
        total_iters (int): The number of steps for which the scheduler multiplies the learning rate by the factor.
624            Default: 5.
625        last_epoch (int): The index of the last epoch. Default: -1.
626        verbose (bool | str): If ``True``, prints a message to stdout for
627            each update. Default: ``False``.
628
629            .. deprecated:: 2.2
630                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
631                learning rate.
632
633    Example:
634        >>> # xdoctest: +SKIP
635        >>> # Assuming optimizer uses lr = 0.05 for all groups
636        >>> # lr = 0.025   if epoch == 0
637        >>> # lr = 0.025   if epoch == 1
638        >>> # lr = 0.025   if epoch == 2
639        >>> # lr = 0.025   if epoch == 3
640        >>> # lr = 0.05    if epoch >= 4
641        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
642        >>> for epoch in range(100):
643        >>>     train(...)
644        >>>     validate(...)
645        >>>     scheduler.step()
646    """
647
648    def __init__(
649        self,
650        optimizer: Optimizer,
651        factor=1.0 / 3,
652        total_iters=5,
653        last_epoch=-1,
654        verbose="deprecated",
655    ):  # noqa: D107
656        if factor > 1.0 or factor < 0:
657            raise ValueError(
658                "Constant multiplicative factor expected to be between 0 and 1."
659            )
660
661        self.factor = factor
662        self.total_iters = total_iters
663        super().__init__(optimizer, last_epoch, verbose)
664
665    def get_lr(self):
666        """Compute the learning rate of each parameter group."""
667        _warn_get_lr_called_within_step(self)
668
669        if self.last_epoch == 0:
670            return [group["lr"] * self.factor for group in self.optimizer.param_groups]
671
672        if self.last_epoch != self.total_iters:
673            return [group["lr"] for group in self.optimizer.param_groups]
674
675        return [
676            group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
677        ]
678
679    def _get_closed_form_lr(self):
680        return [
681            base_lr
682            * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
683            for base_lr in self.base_lrs
684        ]
685
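    # The boolean (self.last_epoch >= self.total_iters) evaluates to 0 or 1, so the
    # closed-form factor is `self.factor` until last_epoch reaches total_iters and 1.0 from then on.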
686
687class LinearLR(LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing a small multiplicative factor.

    The multiplication is done until the number of epochs reaches a pre-defined milestone: total_iters.
691    Notice that such decay can happen simultaneously with other changes to the learning rate
692    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
693
694    Args:
695        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The factor by which the learning rate is multiplied in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Default: 1./3.
        end_factor (float): The factor by which the learning rate is multiplied at the end of the linear
            changing process. Default: 1.0.
        total_iters (int): The number of iterations after which the multiplicative factor reaches end_factor.
            Default: 5.
703        last_epoch (int): The index of the last epoch. Default: -1.
704        verbose (bool | str): If ``True``, prints a message to stdout for
705            each update. Default: ``False``.
706
707            .. deprecated:: 2.2
708                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
709                learning rate.
710
711    Example:
712        >>> # xdoctest: +SKIP
713        >>> # Assuming optimizer uses lr = 0.05 for all groups
714        >>> # lr = 0.025    if epoch == 0
715        >>> # lr = 0.03125  if epoch == 1
716        >>> # lr = 0.0375   if epoch == 2
717        >>> # lr = 0.04375  if epoch == 3
718        >>> # lr = 0.05    if epoch >= 4
719        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
720        >>> for epoch in range(100):
721        >>>     train(...)
722        >>>     validate(...)
723        >>>     scheduler.step()
724    """
725
726    def __init__(
727        self,
728        optimizer: Optimizer,
729        start_factor=1.0 / 3,
730        end_factor=1.0,
731        total_iters=5,
732        last_epoch=-1,
733        verbose="deprecated",
734    ):  # noqa: D107
735        if start_factor > 1.0 or start_factor <= 0:
736            raise ValueError(
                "Starting multiplicative factor expected to be greater than 0 and less than or equal to 1."
738            )
739
740        if end_factor > 1.0 or end_factor < 0:
741            raise ValueError(
742                "Ending multiplicative factor expected to be between 0 and 1."
743            )
744
745        self.start_factor = start_factor
746        self.end_factor = end_factor
747        self.total_iters = total_iters
748        super().__init__(optimizer, last_epoch, verbose)
749
750    def get_lr(self):
751        """Compute the learning rate."""
752        _warn_get_lr_called_within_step(self)
753
754        if self.last_epoch == 0:
755            return [
756                group["lr"] * self.start_factor for group in self.optimizer.param_groups
757            ]
758
759        if self.last_epoch > self.total_iters:
760            return [group["lr"] for group in self.optimizer.param_groups]
761
762        return [
763            group["lr"]
764            * (
765                1.0
766                + (self.end_factor - self.start_factor)
767                / (
768                    self.total_iters * self.start_factor
769                    + (self.last_epoch - 1) * (self.end_factor - self.start_factor)
770                )
771            )
772            for group in self.optimizer.param_groups
773        ]
774
775    def _get_closed_form_lr(self):
776        return [
777            base_lr
778            * (
779                self.start_factor
780                + (self.end_factor - self.start_factor)
781                * min(self.total_iters, self.last_epoch)
782                / self.total_iters
783            )
784            for base_lr in self.base_lrs
785        ]
786
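    # Worked check of the closed form against the docstring example
    # (start_factor=0.5, end_factor=1.0, total_iters=4, base lr=0.05):
    #   epoch 1 -> 0.05 * (0.5 + 0.5 * 1 / 4) = 0.03125
    #   epoch 2 -> 0.05 * (0.5 + 0.5 * 2 / 4) = 0.0375
    #   epoch 4 -> 0.05 * (0.5 + 0.5 * 4 / 4) = 0.05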
787
788class ExponentialLR(LRScheduler):
789    """Decays the learning rate of each parameter group by gamma every epoch.
790
791    When last_epoch=-1, sets initial lr as lr.
792
793    Args:
794        optimizer (Optimizer): Wrapped optimizer.
795        gamma (float): Multiplicative factor of learning rate decay.
796        last_epoch (int): The index of last epoch. Default: -1.
797        verbose (bool | str): If ``True``, prints a message to stdout for
798            each update. Default: ``False``.
799
800            .. deprecated:: 2.2
801                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
802                learning rate.
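
    Example:
        >>> # xdoctest: +SKIP
        >>> # Sketch in the style of the other schedulers' examples; assumes
        >>> # `optimizer`, `train` and `validate` are defined as in those examples.
        >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()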
803    """
804
805    def __init__(
806        self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated"
807    ):  # noqa: D107
808        self.gamma = gamma
809        super().__init__(optimizer, last_epoch, verbose)
810
811    def get_lr(self):
812        """Compute the learning rate of each parameter group."""
813        _warn_get_lr_called_within_step(self)
814
815        if self.last_epoch == 0:
816            return [group["lr"] for group in self.optimizer.param_groups]
817        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
818
819    def _get_closed_form_lr(self):
820        return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
821
822
823class SequentialLR(LRScheduler):
824    """Contains a list of schedulers expected to be called sequentially during the optimization process.
825
    Specifically, the milestone points determine which scheduler is active at a given epoch: each scheduler
    runs until the next milestone is reached, at which point the next scheduler in the list takes over.
828
829    Args:
830        optimizer (Optimizer): Wrapped optimizer.
831        schedulers (list): List of chained schedulers.
832        milestones (list): List of integers that reflects milestone points.
833        last_epoch (int): The index of last epoch. Default: -1.
834        verbose (bool | str): Does nothing.
835
836            .. deprecated:: 2.2
837                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
838                learning rate.
839
840    Example:
841        >>> # xdoctest: +SKIP
842        >>> # Assuming optimizer uses lr = 1. for all groups
843        >>> # lr = 0.1     if epoch == 0
844        >>> # lr = 0.1     if epoch == 1
845        >>> # lr = 0.9     if epoch == 2
846        >>> # lr = 0.81    if epoch == 3
847        >>> # lr = 0.729   if epoch == 4
848        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
849        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
850        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
851        >>> for epoch in range(100):
852        >>>     train(...)
853        >>>     validate(...)
854        >>>     scheduler.step()
855    """
856
857    def __init__(
858        self,
859        optimizer: Optimizer,
860        schedulers: List[LRScheduler],
861        milestones: List[int],
862        last_epoch=-1,
863        verbose="deprecated",
864    ):  # noqa: D107
865        if len(schedulers) < 1:
866            raise ValueError(
867                f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
868            )
869
870        for scheduler_idx, scheduler in enumerate(schedulers):
871            if not hasattr(scheduler, "optimizer"):
872                raise TypeError(
873                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
874                )
875            if isinstance(scheduler, ReduceLROnPlateau):
876                raise ValueError(
877                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
878                    "requires additional kwargs to be specified when calling `step`, "
879                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
880                )
881            if optimizer != scheduler.optimizer:
882                raise ValueError(
883                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} bound to {scheduler.optimizer}, "
885                    f"which is different from {optimizer.__class__.__name__}."
886                )
887
888        if len(milestones) != len(schedulers) - 1:
889            raise ValueError(
                "Sequential Schedulers expects the number of schedulers provided to be one more "
                f"than the number of milestone points, but got {len(schedulers)} schedulers and "
                f"{len(milestones)} milestones."
893            )
894        _check_verbose_deprecated_warning(verbose)
895        self._schedulers = schedulers
896        self._milestones = milestones
897        self.last_epoch = last_epoch + 1
898        self.optimizer = optimizer
899
900        # Reset learning rates back to initial values
901        for group in self.optimizer.param_groups:
902            group["lr"] = group["initial_lr"]
903
904        # "Undo" the step performed by other schedulers
905        for scheduler in self._schedulers:
906            scheduler.last_epoch -= 1
907
908        # Perform the initial step for only the first scheduler
909        self._schedulers[0]._initial_step()
910
911        self._last_lr = schedulers[0].get_last_lr()
912
913    def step(self):
914        """Perform a step."""
915        self.last_epoch += 1
916        idx = bisect_right(self._milestones, self.last_epoch)
917        scheduler = self._schedulers[idx]
918        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
919            scheduler.step(0)
920        else:
921            scheduler.step()
922
923        self._last_lr = scheduler.get_last_lr()
924
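    # e.g. with milestones=[2]: epochs 0 and 1 step self._schedulers[0]; at epoch 2
    # the second scheduler is selected and stepped with epoch 0, so it starts from
    # the beginning of its own schedule.
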
925    def state_dict(self):
926        """Return the state of the scheduler as a :class:`dict`.
927
928        It contains an entry for every variable in self.__dict__ which
929        is not the optimizer.
930        The wrapped scheduler states will also be saved.
931        """
932        state_dict = {
933            key: value
934            for key, value in self.__dict__.items()
935            if key not in ("optimizer", "_schedulers")
936        }
937        state_dict["_schedulers"] = [None] * len(self._schedulers)
938
939        for idx, s in enumerate(self._schedulers):
940            state_dict["_schedulers"][idx] = s.state_dict()
941
942        return state_dict
943
944    def load_state_dict(self, state_dict):
945        """Load the scheduler's state.
946
947        Args:
948            state_dict (dict): scheduler state. Should be an object returned
949                from a call to :meth:`state_dict`.
950        """
951        _schedulers = state_dict.pop("_schedulers")
952        self.__dict__.update(state_dict)
953        # Restore state_dict keys in order to prevent side effects
954        # https://github.com/pytorch/pytorch/issues/32756
955        state_dict["_schedulers"] = _schedulers
956
957        for idx, s in enumerate(_schedulers):
958            self._schedulers[idx].load_state_dict(s)
959
960
961class PolynomialLR(LRScheduler):
    """Decays the learning rate of each parameter group using a polynomial function over the given total_iters.
963
964    When last_epoch=-1, sets initial lr as lr.
965
966    Args:
967        optimizer (Optimizer): Wrapped optimizer.
968        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
        power (float): The power of the polynomial. Default: 1.0.
        last_epoch (int): The index of the last epoch. Default: -1.
970        verbose (bool | str): If ``True``, prints a message to stdout for
971            each update. Default: ``False``.
972
973            .. deprecated:: 2.2
974                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
975                learning rate.
976
977    Example:
978        >>> # xdoctest: +SKIP("undefined vars")
979        >>> # Assuming optimizer uses lr = 0.001 for all groups
980        >>> # lr = 0.001     if epoch == 0
981        >>> # lr = 0.00075   if epoch == 1
982        >>> # lr = 0.00050   if epoch == 2
983        >>> # lr = 0.00025   if epoch == 3
984        >>> # lr = 0.0       if epoch >= 4
985        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
986        >>> for epoch in range(100):
987        >>>     train(...)
988        >>>     validate(...)
989        >>>     scheduler.step()
990    """
991
992    def __init__(
993        self,
994        optimizer: Optimizer,
995        total_iters=5,
996        power=1.0,
997        last_epoch=-1,
998        verbose="deprecated",
999    ):  # noqa: D107
1000        self.total_iters = total_iters
1001        self.power = power
1002        super().__init__(optimizer, last_epoch, verbose)
1003
1004    def get_lr(self):
1005        """Compute the learning rate."""
1006        _warn_get_lr_called_within_step(self)
1007
1008        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
1009            return [group["lr"] for group in self.optimizer.param_groups]
1010
1011        decay_factor = (
1012            (1.0 - self.last_epoch / self.total_iters)
1013            / (1.0 - (self.last_epoch - 1) / self.total_iters)
1014        ) ** self.power
1015        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
1016
1017    def _get_closed_form_lr(self):
1018        return [
1019            (
1020                base_lr
1021                * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters)
1022                ** self.power
1023            )
1024            for base_lr in self.base_lrs
1025        ]
1026
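    # Worked check against the docstring example (total_iters=4, power=1.0,
    # base lr=0.001): epoch 1 -> 0.001 * (1 - 1/4) = 0.00075,
    # epoch 3 -> 0.001 * (1 - 3/4) = 0.00025, and 0.0 for epoch >= 4.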
1027
1028class CosineAnnealingLR(LRScheduler):
1029    r"""Set the learning rate of each parameter group using a cosine annealing schedule.
1030
1031    The :math:`\eta_{max}` is set to the initial lr and
1032    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
1033
1034    .. math::
1035        \begin{aligned}
1036            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
1037            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
1038            & T_{cur} \neq (2k+1)T_{max}; \\
1039            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
1040            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
1041            & T_{cur} = (2k+1)T_{max}.
1042        \end{aligned}
1043
1044    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
1045    is defined recursively, the learning rate can be simultaneously modified
1046    outside this scheduler by other operators. If the learning rate is set
1047    solely by this scheduler, the learning rate at each step becomes:
1048
1049    .. math::
1050        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
1051        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
1052
1053    It has been proposed in
1054    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
1055    implements the cosine annealing part of SGDR, and not the restarts.
1056
1057    Args:
1058        optimizer (Optimizer): Wrapped optimizer.
1059        T_max (int): Maximum number of iterations.
1060        eta_min (float): Minimum learning rate. Default: 0.
1061        last_epoch (int): The index of last epoch. Default: -1.
1062        verbose (bool | str): If ``True``, prints a message to stdout for
1063            each update. Default: ``False``.
1064
1065            .. deprecated:: 2.2
1066                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1067                learning rate.
1068
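    Example:
        >>> # xdoctest: +SKIP
        >>> # Sketch in the style of the other schedulers' examples; assumes
        >>> # `optimizer`, `train` and `validate` are defined as in those examples.
        >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
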
1069    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
1070        https://arxiv.org/abs/1608.03983
1071    """
1072
1073    def __init__(
1074        self,
1075        optimizer: Optimizer,
1076        T_max: int,
1077        eta_min=0.0,
1078        last_epoch=-1,
1079        verbose="deprecated",
1080    ):  # noqa: D107
1081        self.T_max = T_max
1082        self.eta_min = eta_min
1083        super().__init__(optimizer, last_epoch, verbose)
1084
1085    def get_lr(self):
1086        """Retrieve the learning rate of each parameter group."""
1087        _warn_get_lr_called_within_step(self)
1088
1089        if self.last_epoch == 0:
1090            return [group["lr"] for group in self.optimizer.param_groups]
1091        elif self._step_count == 1 and self.last_epoch > 0:
1092            return [
1093                self.eta_min
1094                + (base_lr - self.eta_min)
1095                * (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
1096                / 2
1097                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
1098            ]
1099        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
1100            return [
1101                group["lr"]
1102                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
1103                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
1104            ]
1105        return [
1106            (1 + math.cos(math.pi * self.last_epoch / self.T_max))
1107            / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))
1108            * (group["lr"] - self.eta_min)
1109            + self.eta_min
1110            for group in self.optimizer.param_groups
1111        ]
1112
1113    def _get_closed_form_lr(self):
1114        return [
1115            self.eta_min
1116            + (base_lr - self.eta_min)
1117            * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
1118            / 2
1119            for base_lr in self.base_lrs
1120        ]
1121
1122
1123class ChainedScheduler(LRScheduler):
1124    """Chains a list of learning rate schedulers.
1125
1126    Takes in a sequence of chainable learning rate schedulers and calls their
1127    step() functions consecutively in just one call to step().
1128
1129    Args:
1130        schedulers (sequence): sequence of chained schedulers.
1131        optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
1132
1133    Example:
1134        >>> # xdoctest: +SKIP
1135        >>> # Assuming optimizer uses lr = 1. for all groups
1136        >>> # lr = 0.09     if epoch == 0
1137        >>> # lr = 0.081    if epoch == 1
1138        >>> # lr = 0.729    if epoch == 2
1139        >>> # lr = 0.6561   if epoch == 3
1140        >>> # lr = 0.59049  if epoch >= 4
1141        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
1142        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
1143        >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
1144        >>> for epoch in range(100):
1145        >>>     train(...)
1146        >>>     validate(...)
1147        >>>     scheduler.step()
1148    """
1149
1150    def __init__(
1151        self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None
1152    ):  # noqa: D107
1153        if len(schedulers) < 1:
1154            raise ValueError(
1155                f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
1156            )
1157
1158        optimizer = optimizer or schedulers[0].optimizer
1159        for scheduler_idx, scheduler in enumerate(schedulers):
1160            if not hasattr(scheduler, "optimizer"):
1161                raise TypeError(
1162                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
1163                )
1164            if isinstance(scheduler, ReduceLROnPlateau):
1165                raise ValueError(
1166                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
1167                    "requires additional kwargs to be specified when calling `step`, "
1168                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
1169                )
1170            if optimizer != scheduler.optimizer:
1171                raise ValueError(
1172                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} bound to {scheduler.optimizer}, "
1174                    f"which is different from {optimizer.__class__.__name__}."
1175                )
1176        self._schedulers = schedulers
1177        self.optimizer = optimizer
1178        self._last_lr = [
1179            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
1180        ]
1181
1182    def step(self):
1183        """Perform a step."""
1184        for scheduler in self._schedulers:
1185            scheduler.step()
1186        self._last_lr = [
1187            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
1188        ]
1189
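    # Unlike SequentialLR, every wrapped scheduler is stepped on every call, so
    # their effects compose (e.g. ConstantLR and ExponentialLR both contribute to
    # the values in the docstring example above).
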
1190    def state_dict(self):
1191        """Return the state of the scheduler as a :class:`dict`.
1192
1193        It contains an entry for every variable in self.__dict__ which
1194        is not the optimizer.
1195        The wrapped scheduler states will also be saved.
1196        """
1197        state_dict = {
1198            key: value
1199            for key, value in self.__dict__.items()
1200            if key not in ("optimizer", "_schedulers")
1201        }
1202        state_dict["_schedulers"] = [None] * len(self._schedulers)
1203
1204        for idx, s in enumerate(self._schedulers):
1205            state_dict["_schedulers"][idx] = s.state_dict()
1206
1207        return state_dict
1208
1209    def load_state_dict(self, state_dict):
1210        """Load the scheduler's state.
1211
1212        Args:
1213            state_dict (dict): scheduler state. Should be an object returned
1214                from a call to :meth:`state_dict`.
1215        """
1216        _schedulers = state_dict.pop("_schedulers")
1217        self.__dict__.update(state_dict)
1218        # Restore state_dict keys in order to prevent side effects
1219        # https://github.com/pytorch/pytorch/issues/32756
1220        state_dict["_schedulers"] = _schedulers
1221
1222        for idx, s in enumerate(_schedulers):
1223            self._schedulers[idx].load_state_dict(s)
1224
1225
1226class ReduceLROnPlateau(LRScheduler):
1227    """Reduce learning rate when a metric has stopped improving.
1228
1229    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, reduces the learning rate.
1233
1234    Args:
1235        optimizer (Optimizer): Wrapped optimizer.
1236        mode (str): One of `min`, `max`. In `min` mode, lr will
1237            be reduced when the quantity monitored has stopped
1238            decreasing; in `max` mode it will be reduced when the
1239            quantity monitored has stopped increasing. Default: 'min'.
1240        factor (float): Factor by which the learning rate will be
1241            reduced. new_lr = lr * factor. Default: 0.1.
1242        patience (int): The number of allowed epochs with no improvement after
1243            which the learning rate will be reduced.
1244            For example, consider the case of having no patience (`patience = 0`).
1245            In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
1246            In the second epoch, if the performance is worse than the baseline,
1247            we have what is considered an intolerable epoch.
1248            Since the count of intolerable epochs (1) is greater than the patience level (0),
1249            the learning rate is reduced at the end of this epoch.
1250            From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
1251            if the performance is worse than the baseline. If the performance improves or remains the same,
1252            the learning rate is not adjusted.
1253            Default: 10.
1254        threshold (float): Threshold for measuring the new optimum,
1255            to only focus on significant changes. Default: 1e-4.
1256        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
1257            dynamic_threshold = best * ( 1 + threshold ) in 'max'
1258            mode or best * ( 1 - threshold ) in `min` mode.
1259            In `abs` mode, dynamic_threshold = best + threshold in
1260            `max` mode or best - threshold in `min` mode. Default: 'rel'.
1261        cooldown (int): Number of epochs to wait before resuming
1262            normal operation after lr has been reduced. Default: 0.
1263        min_lr (float or list): A scalar or a list of scalars. A
1264            lower bound on the learning rate of all param groups
1265            or each group respectively. Default: 0.
1266        eps (float): Minimal decay applied to lr. If the difference
1267            between new and old lr is smaller than eps, the update is
1268            ignored. Default: 1e-8.
1269        verbose (bool | str): If ``True``, prints a message to stdout for
1270            each update. Default: ``False``.
1271
1272            .. deprecated:: 2.2
1273                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1274                learning rate.
1275
1276    Example:
1277        >>> # xdoctest: +SKIP
1278        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1279        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
1280        >>> for epoch in range(10):
1281        >>>     train(...)
1282        >>>     val_loss = validate(...)
1283        >>>     # Note that step should be called after validate()
1284        >>>     scheduler.step(val_loss)
1285    """
1286
1287    def __init__(
1288        self,
1289        optimizer: Optimizer,
1290        mode: Literal["min", "max"] = "min",
1291        factor=0.1,
1292        patience=10,
1293        threshold=1e-4,
1294        threshold_mode: Literal["rel", "abs"] = "rel",
1295        cooldown=0,
1296        min_lr: Union[List[float], float] = 0,
1297        eps=1e-8,
1298        verbose="deprecated",
1299    ):  # noqa: D107
1300        if factor >= 1.0:
1301            raise ValueError("Factor should be < 1.0.")
1302        self.factor = factor
1303
1304        # Attach optimizer
1305        if not isinstance(optimizer, Optimizer):
1306            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
1307        self.optimizer = optimizer
1308
1309        if isinstance(min_lr, (list, tuple)):
1310            if len(min_lr) != len(optimizer.param_groups):
1311                raise ValueError(
1312                    f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
1313                )
1314            self.min_lrs = list(min_lr)
1315        else:
1316            self.min_lrs = [min_lr] * len(optimizer.param_groups)
1317
1318        self.patience = patience
1319
1320        self.verbose = _check_verbose_deprecated_warning(verbose)
1321        self.cooldown = cooldown
1322        self.cooldown_counter = 0
1323        self.mode = mode
1324        self.threshold = threshold
1325        self.threshold_mode = threshold_mode
1326        self.best: float
1327        self.num_bad_epochs: int
1328        self.mode_worse: float  # the worse value for the chosen mode
1329        self.eps = eps
1330        self.last_epoch = 0
1331        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
1332        self._init_is_better(
1333            mode=mode, threshold=threshold, threshold_mode=threshold_mode
1334        )
1335        self._reset()
1336
1337    def _reset(self):
1338        """Reset num_bad_epochs counter and cooldown counter."""
1339        self.best = self.mode_worse
1340        self.cooldown_counter = 0
1341        self.num_bad_epochs = 0
1342
1343    def step(self, metrics: SupportsFloat, epoch=None):  # type: ignore[override]
1344        """Perform a step."""
1345        # convert `metrics` to float, in case it's a zero-dim Tensor
1346        current = float(metrics)
1347        if epoch is None:
1348            epoch = self.last_epoch + 1
1349        else:
1350            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
1351        self.last_epoch = epoch
1352
1353        if self.is_better(current, self.best):
1354            self.best = current
1355            self.num_bad_epochs = 0
1356        else:
1357            self.num_bad_epochs += 1
1358
1359        if self.in_cooldown:
1360            self.cooldown_counter -= 1
1361            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown
1362
1363        if self.num_bad_epochs > self.patience:
1364            self._reduce_lr(epoch)
1365            self.cooldown_counter = self.cooldown
1366            self.num_bad_epochs = 0
1367
1368        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
1369
1370    def _reduce_lr(self, epoch):
1371        for i, param_group in enumerate(self.optimizer.param_groups):
1372            old_lr = float(param_group["lr"])
1373            new_lr = max(old_lr * self.factor, self.min_lrs[i])
1374            if old_lr - new_lr > self.eps:
1375                param_group["lr"] = new_lr
1376
1377    @property
1378    def in_cooldown(self):  # noqa: D102
1379        return self.cooldown_counter > 0
1380
1381    def is_better(self, a, best):  # noqa: D102
1382        if self.mode == "min" and self.threshold_mode == "rel":
1383            rel_epsilon = 1.0 - self.threshold
1384            return a < best * rel_epsilon
1385
1386        elif self.mode == "min" and self.threshold_mode == "abs":
1387            return a < best - self.threshold
1388
1389        elif self.mode == "max" and self.threshold_mode == "rel":
1390            rel_epsilon = self.threshold + 1.0
1391            return a > best * rel_epsilon
1392
1393        else:  # mode == 'max' and threshold_mode == 'abs':
1394            return a > best + self.threshold
1395
1396    def _init_is_better(self, mode, threshold, threshold_mode):
1397        if mode not in {"min", "max"}:
1398            raise ValueError(f"mode {mode} is unknown!")
1399        if threshold_mode not in {"rel", "abs"}:
1400            raise ValueError(f"threshold mode {threshold_mode} is unknown!")
1401
1402        if mode == "min":
1403            self.mode_worse = inf
1404        else:  # mode == 'max':
1405            self.mode_worse = -inf
1406
1407        self.mode = mode
1408        self.threshold = threshold
1409        self.threshold_mode = threshold_mode
1410
1411    def state_dict(self):  # noqa: D102
1412        return {
1413            key: value for key, value in self.__dict__.items() if key != "optimizer"
1414        }
1415
1416    def load_state_dict(self, state_dict):
1417        """Load the scheduler's state."""
1418        self.__dict__.update(state_dict)
1419        self._init_is_better(
1420            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
1421        )
1422
1423
1424class CyclicLR(LRScheduler):
1425    r"""Sets the learning rate of each parameter group according to the cyclical learning rate policy (CLR).
1426
1427    The policy cycles the learning rate between two boundaries with a constant frequency,
1428    as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_.
1429    The distance between the two boundaries can be scaled on a per-iteration
1430    or per-cycle basis.
1431
1432    The cyclical learning rate policy changes the learning rate after every batch.
1433    `step` should be called after a batch has been used for training.
1434
1435    This class has three built-in policies, as put forth in the paper:
1436
1437    * "triangular": A basic triangular cycle without amplitude scaling.
1438    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
1439    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
1440      at each cycle iteration.
1441
1442    This implementation was adapted from the GitHub repo `bckenstler/CLR`_.
1443
1444    Args:
1445        optimizer (Optimizer): Wrapped optimizer.
1446        base_lr (float or list): Initial learning rate which is the
1447            lower boundary in the cycle for each parameter group.
1448        max_lr (float or list): Upper learning rate boundaries in the cycle
1449            for each parameter group. Functionally,
1450            it defines the cycle amplitude (max_lr - base_lr).
1451            The lr at any cycle is the sum of base_lr
1452            and some scaling of the amplitude; therefore
1453            max_lr may not actually be reached depending on
1454            scaling function.
1455        step_size_up (int): Number of training iterations in the
1456            increasing half of a cycle. Default: 2000
1457        step_size_down (int): Number of training iterations in the
1458            decreasing half of a cycle. If step_size_down is None,
1459            it is set to step_size_up. Default: None
1460        mode (str): One of {triangular, triangular2, exp_range}.
1461            Values correspond to policies detailed above.
1462            If scale_fn is not None, this argument is ignored.
1463            Default: 'triangular'
1464        gamma (float): Constant in 'exp_range' scaling function:
1465            gamma**(cycle iterations)
1466            Default: 1.0
1467        scale_fn (function): Custom scaling policy defined by a single
1468            argument lambda function, where
1469            0 <= scale_fn(x) <= 1 for all x >= 0.
1470            If specified, then 'mode' is ignored.
1471            Default: None
1472        scale_mode (str): {'cycle', 'iterations'}.
1473            Defines whether scale_fn is evaluated on
1474            cycle number or cycle iterations (training
1475            iterations since start of cycle).
1476            Default: 'cycle'
1477        cycle_momentum (bool): If ``True``, momentum is cycled inversely
1478            to learning rate between 'base_momentum' and 'max_momentum'.
1479            Default: True
1480        base_momentum (float or list): Lower momentum boundaries in the cycle
1481            for each parameter group. Note that momentum is cycled inversely
1482            to learning rate; at the peak of a cycle, momentum is
1483            'base_momentum' and learning rate is 'max_lr'.
1484            Default: 0.8
1485        max_momentum (float or list): Upper momentum boundaries in the cycle
1486            for each parameter group. Functionally,
1487            it defines the cycle amplitude (max_momentum - base_momentum).
1488            The momentum at any cycle is the difference of max_momentum
1489            and some scaling of the amplitude; therefore
1490            base_momentum may not actually be reached depending on
1491            scaling function. Note that momentum is cycled inversely
1492            to learning rate; at the start of a cycle, momentum is 'max_momentum'
1493            and learning rate is 'base_lr'
1494            Default: 0.9
1495        last_epoch (int): The index of the last batch. This parameter is used when
1496            resuming a training job. Since `step()` should be invoked after each
1497            batch instead of after each epoch, this number represents the total
1498            number of *batches* computed, not the total number of epochs computed.
1499            When last_epoch=-1, the schedule is started from the beginning.
1500            Default: -1
1501        verbose (bool | str): If ``True``, prints a message to stdout for
1502            each update. Default: ``False``.
1503
1504            .. deprecated:: 2.2
1505                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1506                learning rate.
1507
1508    Example:
1509        >>> # xdoctest: +SKIP
1510        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1511        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
1512        >>> data_loader = torch.utils.data.DataLoader(...)
1513        >>> for epoch in range(10):
1514        >>>     for batch in data_loader:
1515        >>>         train_batch(...)
1516        >>>         scheduler.step()
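        >>>
        >>> # A sketch of a custom scaling policy (a non-None scale_fn makes
        >>> # `mode` be ignored); the decay constant 0.99995 is illustrative only.
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(
        >>>     optimizer, base_lr=0.01, max_lr=0.1,
        >>>     scale_fn=lambda x: 0.99995 ** x, scale_mode="iterations",
        >>> )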
1517
1518
1519    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
1520    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
1521    """
1522
1523    def __init__(
1524        self,
1525        optimizer: Optimizer,
1526        base_lr: Union[float, List[float]],
1527        max_lr: Union[float, List[float]],
1528        step_size_up=2000,
1529        step_size_down: Optional[int] = None,
1530        mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
1531        gamma=1.0,
1532        scale_fn: Optional[Callable[[float], float]] = None,
1533        scale_mode: Literal["cycle", "iterations"] = "cycle",
1534        cycle_momentum=True,
1535        base_momentum=0.8,
1536        max_momentum=0.9,
1537        last_epoch=-1,
1538        verbose="deprecated",
1539    ):  # noqa: D107
1540        # Attach optimizer
1541        if not isinstance(optimizer, Optimizer):
1542            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
1543        self.optimizer = optimizer
1544
1545        base_lrs = _format_param("base_lr", optimizer, base_lr)
1546        if last_epoch == -1:
1547            for lr, group in zip(base_lrs, optimizer.param_groups):
1548                if isinstance(group["lr"], Tensor):
1549                    lr_val = lr.item() if isinstance(lr, Tensor) else lr
1550                    group["lr"].fill_(lr_val)
1551                else:
1552                    group["lr"] = lr
1553
1554        self.max_lrs = _format_param("max_lr", optimizer, max_lr)
1555
1556        step_size_up = float(step_size_up)
1557        step_size_down = (
1558            float(step_size_down) if step_size_down is not None else step_size_up
1559        )
1560        self.total_size = step_size_up + step_size_down
1561        self.step_ratio = step_size_up / self.total_size
1562
1563        if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
1564            raise ValueError(f"mode '{mode}' is invalid and scale_fn is None")
1565
1566        self.mode = mode
1567        self.gamma = gamma
1568
1569        self._scale_fn_ref: Callable[[float], float]
1570        self._scale_fn_custom = scale_fn
1571        self.scale_mode = scale_mode
1572        self._init_scale_fn()
1573
1574        self.cycle_momentum = cycle_momentum
1575        if cycle_momentum:
1576            if (
1577                "momentum" not in optimizer.defaults
1578                and "betas" not in optimizer.defaults
1579            ):
1580                raise ValueError(
1581                    "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
1582                )
1583
1584            self.use_beta1 = "betas" in self.optimizer.defaults
1585            self.base_momentums = _format_param(
1586                "base_momentum", optimizer, base_momentum
1587            )
1588            self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
1589            if last_epoch == -1:
1590                for m_momentum, b_momentum, group in zip(
1591                    self.max_momentums, self.base_momentums, optimizer.param_groups
1592                ):
1593                    if self.use_beta1:
1594                        group["betas"] = (m_momentum, *group["betas"][1:])
1595                    else:
1596                        group["momentum"] = m_momentum
1597                    group["max_momentum"] = m_momentum
1598                    group["base_momentum"] = b_momentum
1599
1600        super().__init__(optimizer, last_epoch, verbose)
1601        self.base_lrs = base_lrs
1602
1603    def _init_scale_fn(self):
1604        if self._scale_fn_custom is not None:
1605            return
1606        if self.mode == "triangular":
1607            self._scale_fn_ref = self._triangular_scale_fn
1608            self.scale_mode = "cycle"
1609        elif self.mode == "triangular2":
1610            self._scale_fn_ref = self._triangular2_scale_fn
1611            self.scale_mode = "cycle"
1612        elif self.mode == "exp_range":
1613            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
1614            self.scale_mode = "iterations"
1615
1616    def scale_fn(self, x) -> float:
1617        """Get the scaling policy."""
1618        if self._scale_fn_custom is not None:
1619            return self._scale_fn_custom(x)
1620        else:
1621            return self._scale_fn_ref(x)  # built-in policy (staticmethod or partial)
1622
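    # The built-in policies below follow the paper: 'triangular' keeps full
    # amplitude every cycle, 'triangular2' halves it each cycle via
    # 1 / 2**(x - 1), and 'exp_range' decays it by gamma**x (x = iterations).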
1623    @staticmethod
1624    def _triangular_scale_fn(x: float) -> float:
1625        return 1.0
1626
1627    @staticmethod
1628    def _triangular2_scale_fn(x: float) -> float:
1629        return 1 / (2.0 ** (x - 1))
1630
1631    @staticmethod
1632    def _exp_range_scale_fn(gamma: float, x: float) -> float:
1633        return gamma**x
1634
1635    def get_lr(self):
1636        """Calculate the learning rate at batch index.
1637
1638        This function treats `self.last_epoch` as the last batch index.
1639
1640        If `self.cycle_momentum` is ``True``, this function has a side effect of
1641        updating the optimizer's momentum.
1642        """
1643        _warn_get_lr_called_within_step(self)
1644
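        # `cycle` is the 1-based index of the current cycle and `x` the
        # fractional position within it; `scale_factor` ramps from 0 to 1 over
        # the rising half (x <= step_ratio) and back to 0 over the falling half.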
1645        cycle = math.floor(1 + self.last_epoch / self.total_size)
1646        x = 1.0 + self.last_epoch / self.total_size - cycle
1647        if x <= self.step_ratio:
1648            scale_factor = x / self.step_ratio
1649        else:
1650            scale_factor = (x - 1) / (self.step_ratio - 1)
1651
1652        lrs = []
1653        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
1654            base_height = (max_lr - base_lr) * scale_factor
1655            if self.scale_mode == "cycle":
1656                lr = base_lr + base_height * self.scale_fn(cycle)
1657            else:
1658                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
1659            lrs.append(lr)
1660
1661        if self.cycle_momentum:
1662            momentums = []
1663            for base_momentum, max_momentum in zip(
1664                self.base_momentums, self.max_momentums
1665            ):
1666                base_height = (max_momentum - base_momentum) * scale_factor
1667                if self.scale_mode == "cycle":
1668                    momentum = max_momentum - base_height * self.scale_fn(cycle)
1669                else:
1670                    momentum = max_momentum - base_height * self.scale_fn(
1671                        self.last_epoch
1672                    )
1673                momentums.append(momentum)
1674            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
1675                if self.use_beta1:
1676                    param_group["betas"] = (momentum, *param_group["betas"][1:])
1677                else:
1678                    param_group["momentum"] = momentum
1679
1680        return lrs
1681
1682    def state_dict(self):  # noqa: D102
1683        state = super().state_dict()
1684        # Drop `_scale_fn_ref`: it is re-created by `_init_scale_fn()` when
1685        # the state is loaded, so it does not need to be pickled.
1686        state.pop("_scale_fn_ref", None)
1687        fn = state.pop("_scale_fn_custom")
1688        state["_scale_fn_custom"] = None
1689        if fn is not None and not isinstance(fn, types.FunctionType):
1690            # _scale_fn_custom is only saved (via its __dict__) when it is a
1691            # callable object, not a plain function or lambda.
1692            state["_scale_fn_custom"] = fn.__dict__.copy()
1693
1694        return state
1695
1696    def load_state_dict(self, state_dict):
1697        """Load the scheduler's state."""
1698        fn = state_dict.pop("_scale_fn_custom")
1699        super().load_state_dict(state_dict)
1700        if fn is not None:
1701            self._scale_fn_custom.__dict__.update(fn)
1702        self._init_scale_fn()
1703
1704
1705class CosineAnnealingWarmRestarts(LRScheduler):
1706    r"""Set the learning rate of each parameter group using a cosine annealing schedule.
1707
1708    The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
1709    is the number of epochs since the last restart and :math:`T_{i}` is the number
1710    of epochs between two warm restarts in SGDR:
1711
1712    .. math::
1713        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
1714        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
1715
1716    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
1717    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
1718
1719    It has been proposed in
1720    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
1721
1722    Args:
1723        optimizer (Optimizer): Wrapped optimizer.
1724        T_0 (int): Number of iterations until the first restart.
1725        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
1726        eta_min (float, optional): Minimum learning rate. Default: 0.
1727        last_epoch (int, optional): The index of the last epoch. Default: -1.
1728        verbose (bool | str): If ``True``, prints a message to stdout for
1729            each update. Default: ``False``.
1730
1731            .. deprecated:: 2.2
1732                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1733                learning rate.
1734
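    Example:
        >>> # xdoctest: +SKIP("Undefined vars")
        >>> # Minimal usage sketch; `model` and `train_batch` are assumed to be
        >>> # defined elsewhere.
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        >>> for epoch in range(100):
        >>>     train_batch(...)
        >>>     scheduler.step()
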
1735    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
1736        https://arxiv.org/abs/1608.03983
1737    """
1738
1739    def __init__(
1740        self,
1741        optimizer: Optimizer,
1742        T_0: int,
1743        T_mult=1,
1744        eta_min=0.0,
1745        last_epoch=-1,
1746        verbose="deprecated",
1747    ):  # noqa: D107
1748        if T_0 <= 0 or not isinstance(T_0, int):
1749            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
1750        if T_mult < 1 or not isinstance(T_mult, int):
1751            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
1752        if not isinstance(eta_min, (float, int)):
1753            raise ValueError(
1754                f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}"
1755            )
1756        self.T_0 = T_0
1757        self.T_i = T_0
1758        self.T_mult = T_mult
1759        self.eta_min = eta_min
1760        self.T_cur = last_epoch
1761        super().__init__(optimizer, last_epoch, verbose)
1762
1763    def get_lr(self):
1764        """Compute the learning rate of each parameter group."""
1765        _warn_get_lr_called_within_step(self)
1766
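        # eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2,
        # with eta_max taken to be each group's base learning rate.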
1767        return [
1768            self.eta_min
1769            + (base_lr - self.eta_min)
1770            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
1771            / 2
1772            for base_lr in self.base_lrs
1773        ]
1774
1775    def step(self, epoch=None):
1776        """Step could be called after every batch update.
1777
1778        Example:
1779            >>> # xdoctest: +SKIP("Undefined vars")
1780            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
1781            >>> iters = len(dataloader)
1782            >>> for epoch in range(20):
1783            >>>     for i, sample in enumerate(dataloader):
1784            >>>         inputs, labels = sample['inputs'], sample['labels']
1785            >>>         optimizer.zero_grad()
1786            >>>         outputs = net(inputs)
1787            >>>         loss = criterion(outputs, labels)
1788            >>>         loss.backward()
1789            >>>         optimizer.step()
1790            >>>         scheduler.step(epoch + i / iters)
1791
1792        This function can be called in an interleaved way.
1793
1794        Example:
1795            >>> # xdoctest: +SKIP("Undefined vars")
1796            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
1797            >>> for epoch in range(20):
1798            >>>     scheduler.step()
1799            >>> scheduler.step(26)
1800            >>> scheduler.step()  # equivalent to scheduler.step(27); the counter resumes from 26, not 20
1801        """
1802        if epoch is None and self.last_epoch < 0:
1803            epoch = 0
1804
1805        if epoch is None:
1806            epoch = self.last_epoch + 1
1807            self.T_cur = self.T_cur + 1
1808            if self.T_cur >= self.T_i:
1809                self.T_cur = self.T_cur - self.T_i
1810                self.T_i = self.T_i * self.T_mult
1811        else:
1812            if epoch < 0:
1813                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
1814            if epoch >= self.T_0:
1815                if self.T_mult == 1:
1816                    self.T_cur = epoch % self.T_0
1817                else:
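                    # With T_mult > 1 the restart lengths form a geometric
                    # series, so the number of completed restarts n satisfies
                    # T_0 * (T_mult**n - 1) / (T_mult - 1) <= epoch.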
1818                    n = int(
1819                        math.log(
1820                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
1821                        )
1822                    )
1823                    self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
1824                        self.T_mult - 1
1825                    )
1826                    self.T_i = self.T_0 * self.T_mult ** (n)
1827            else:
1828                self.T_i = self.T_0
1829                self.T_cur = epoch
1830        self.last_epoch = math.floor(epoch)
1831
1832        with _enable_get_lr_call(self):
1833            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
1834                param_group, lr = data
1835                param_group["lr"] = lr
1836
1837        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
1838
1839
1840class _SchedulePhase(TypedDict):
1841    end_step: float
1842    start_lr: str
1843    end_lr: str
1844    start_momentum: str
1845    end_momentum: str
1846
1847
1848class OneCycleLR(LRScheduler):
1849    r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy.
1850
1851    The 1cycle policy anneals the learning rate from an initial learning rate to some maximum
1852    learning rate and then from that maximum learning rate to some minimum learning rate much
1853    lower than the initial learning rate.
1854    This policy was initially described in the paper `Super-Convergence:
1855    Very Fast Training of Neural Networks Using Large Learning Rates`_.
1856
1857    The 1cycle learning rate policy changes the learning rate after every batch.
1858    `step` should be called after a batch has been used for training.
1859
1860    This scheduler is not chainable.
1861
1862    Note also that the total number of steps in the cycle can be determined in one
1863    of two ways (listed in order of precedence):
1864
1865    #. A value for total_steps is explicitly provided.
1866    #. A number of epochs (epochs) and a number of steps per epoch
1867       (steps_per_epoch) are provided.
1868       In this case, the number of total steps is inferred by
1869       total_steps = epochs * steps_per_epoch
1870
1871    You must either provide a value for total_steps or provide a value for both
1872    epochs and steps_per_epoch.
1873
1874    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
1875    claims that "unpublished work has shown even better results by using only two phases". To
1876    mimic the behaviour of the original paper instead, set ``three_phase=True``.
1877
1878    Args:
1879        optimizer (Optimizer): Wrapped optimizer.
1880        max_lr (float or list): Upper learning rate boundaries in the cycle
1881            for each parameter group.
1882        total_steps (int): The total number of steps in the cycle. Note that
1883            if a value is not provided here, then it must be inferred by providing
1884            a value for epochs and steps_per_epoch.
1885            Default: None
1886        epochs (int): The number of epochs to train for. This is used along
1887            with steps_per_epoch in order to infer the total number of steps in the cycle
1888            if a value for total_steps is not provided.
1889            Default: None
1890        steps_per_epoch (int): The number of steps per epoch to train for. This is
1891            used along with epochs in order to infer the total number of steps in the
1892            cycle if a value for total_steps is not provided.
1893            Default: None
1894        pct_start (float): The percentage of the cycle (in number of steps) spent
1895            increasing the learning rate.
1896            Default: 0.3
1897        anneal_strategy (str): {'cos', 'linear'}
1898            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
1899            linear annealing.
1900            Default: 'cos'
1901        cycle_momentum (bool): If ``True``, momentum is cycled inversely
1902            to learning rate between 'base_momentum' and 'max_momentum'.
1903            Default: True
1904        base_momentum (float or list): Lower momentum boundaries in the cycle
1905            for each parameter group. Note that momentum is cycled inversely
1906            to learning rate; at the peak of a cycle, momentum is
1907            'base_momentum' and learning rate is 'max_lr'.
1908            Default: 0.85
1909        max_momentum (float or list): Upper momentum boundaries in the cycle
1910            for each parameter group. Functionally,
1911            it defines the cycle amplitude (max_momentum - base_momentum).
1912            Note that momentum is cycled inversely
1913            to learning rate; at the start of a cycle, momentum is 'max_momentum'
1914            and learning rate is 'base_lr'
1915            Default: 0.95
1916        div_factor (float): Determines the initial learning rate via
1917            initial_lr = max_lr/div_factor
1918            Default: 25
1919        final_div_factor (float): Determines the minimum learning rate via
1920            min_lr = initial_lr/final_div_factor
1921            Default: 1e4
1922        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
1923            learning rate according to 'final_div_factor' instead of modifying the second
1924            phase (the first two phases will be symmetrical about the step indicated by
1925            'pct_start').
1926        last_epoch (int): The index of the last batch. This parameter is used when
1927            resuming a training job. Since `step()` should be invoked after each
1928            batch instead of after each epoch, this number represents the total
1929            number of *batches* computed, not the total number of epochs computed.
1930            When last_epoch=-1, the schedule is started from the beginning.
1931            Default: -1
1932        verbose (bool | str): If ``True``, prints a message to stdout for
1933            each update. Default: ``False``.
1934
1935            .. deprecated:: 2.2
1936                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1937                learning rate.
1938
1939    Example:
1940        >>> # xdoctest: +SKIP
1941        >>> data_loader = torch.utils.data.DataLoader(...)
1942        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
1943        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
1944        >>> for epoch in range(10):
1945        >>>     for batch in data_loader:
1946        >>>         train_batch(...)
1947        >>>         optimizer.step()
1948        >>>         scheduler.step()
1949
1950
1951    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
1952        https://arxiv.org/abs/1708.07120
1953    """
1954
1955    def __init__(
1956        self,
1957        optimizer: Optimizer,
1958        max_lr: Union[float, List[float]],
1959        total_steps: Optional[int] = None,
1960        epochs: Optional[int] = None,
1961        steps_per_epoch: Optional[int] = None,
1962        pct_start=0.3,
1963        anneal_strategy: Literal["cos", "linear"] = "cos",
1964        cycle_momentum=True,
1965        base_momentum: Union[float, List[float]] = 0.85,
1966        max_momentum: Union[float, List[float]] = 0.95,
1967        div_factor=25.0,
1968        final_div_factor=1e4,
1969        three_phase=False,
1970        last_epoch=-1,
1971        verbose="deprecated",
1972    ):  # noqa: D107
1973        # Validate optimizer
1974        if not isinstance(optimizer, Optimizer):
1975            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
1976        self.optimizer = optimizer
1977
1978        # Validate total_steps
1979        if total_steps is not None:
1980            if total_steps <= 0 or not isinstance(total_steps, int):
1981                raise ValueError(
1982                    f"Expected positive integer total_steps, but got {total_steps}"
1983                )
1984            self.total_steps = total_steps
1985        elif epochs is not None and steps_per_epoch is not None:
1986            if not isinstance(epochs, int) or epochs <= 0:
1987                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
1988            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
1989                raise ValueError(
1990                    f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
1991                )
1992            self.total_steps = epochs * steps_per_epoch
1993        else:
1994            raise ValueError(
1995                "You must define either total_steps OR (epochs AND steps_per_epoch)"
1996            )
1997
1998        self._schedule_phases: List[_SchedulePhase]
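        # Each phase anneals lr (and momentum) between the named per-group
        # keys; `end_step` is the last (0-based) step index of that phase, and
        # three_phase=True adds a dedicated annihilation phase down to `min_lr`.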
1999        if three_phase:
2000            self._schedule_phases = [
2001                {
2002                    "end_step": float(pct_start * self.total_steps) - 1,
2003                    "start_lr": "initial_lr",
2004                    "end_lr": "max_lr",
2005                    "start_momentum": "max_momentum",
2006                    "end_momentum": "base_momentum",
2007                },
2008                {
2009                    "end_step": float(2 * pct_start * self.total_steps) - 2,
2010                    "start_lr": "max_lr",
2011                    "end_lr": "initial_lr",
2012                    "start_momentum": "base_momentum",
2013                    "end_momentum": "max_momentum",
2014                },
2015                {
2016                    "end_step": self.total_steps - 1,
2017                    "start_lr": "initial_lr",
2018                    "end_lr": "min_lr",
2019                    "start_momentum": "max_momentum",
2020                    "end_momentum": "max_momentum",
2021                },
2022            ]
2023        else:
2024            self._schedule_phases = [
2025                {
2026                    "end_step": float(pct_start * self.total_steps) - 1,
2027                    "start_lr": "initial_lr",
2028                    "end_lr": "max_lr",
2029                    "start_momentum": "max_momentum",
2030                    "end_momentum": "base_momentum",
2031                },
2032                {
2033                    "end_step": self.total_steps - 1,
2034                    "start_lr": "max_lr",
2035                    "end_lr": "min_lr",
2036                    "start_momentum": "base_momentum",
2037                    "end_momentum": "max_momentum",
2038                },
2039            ]
2040
2041        # Validate pct_start
2042        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
2043            raise ValueError(
2044                f"Expected float between 0 and 1 pct_start, but got {pct_start}"
2045            )
2046
2047        # Validate anneal_strategy
2048        if anneal_strategy not in ["cos", "linear"]:
2049            raise ValueError(
2050                f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
2051            )
2052        else:
2053            self._anneal_func_type = anneal_strategy
2054
2055        # Initialize learning rate variables
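        # initial_lr = max_lr / div_factor and min_lr = initial_lr / final_div_factor,
        # matching the formulas documented in the Args section above.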
2056        max_lrs = _format_param("max_lr", self.optimizer, max_lr)
2057        if last_epoch == -1:
2058            for idx, group in enumerate(self.optimizer.param_groups):
2059                group["initial_lr"] = max_lrs[idx] / div_factor
2060                group["max_lr"] = max_lrs[idx]
2061                group["min_lr"] = group["initial_lr"] / final_div_factor
2062
2063        # Initialize momentum variables
2064        self.cycle_momentum = cycle_momentum
2065        if self.cycle_momentum:
2066            if (
2067                "momentum" not in self.optimizer.defaults
2068                and "betas" not in self.optimizer.defaults
2069            ):
2070                raise ValueError(
2071                    "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
2072                )
2073            self.use_beta1 = "betas" in self.optimizer.defaults
2074            max_momentums = _format_param("max_momentum", optimizer, max_momentum)
2075            base_momentums = _format_param("base_momentum", optimizer, base_momentum)
2076            if last_epoch == -1:
2077                for m_momentum, b_momentum, group in zip(
2078                    max_momentums, base_momentums, optimizer.param_groups
2079                ):
2080                    if self.use_beta1:
2081                        group["betas"] = (m_momentum, *group["betas"][1:])
2082                    else:
2083                        group["momentum"] = m_momentum
2084                    group["max_momentum"] = m_momentum
2085                    group["base_momentum"] = b_momentum
2086
2087        super().__init__(optimizer, last_epoch, verbose)
2088
2089    def _anneal_func(self, *args, **kwargs):
2090        if hasattr(self, "_anneal_func_type"):
2091            if self._anneal_func_type == "cos":
2092                return self._annealing_cos(*args, **kwargs)
2093            elif self._anneal_func_type == "linear":
2094                return self._annealing_linear(*args, **kwargs)
2095            else:
2096                raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}")
2097        else:
2098            # For backward compatibility with older checkpoints/objects that stored `anneal_func`
2099            return self.anneal_func(*args, **kwargs)  # type: ignore[attr-defined]
2100
2101    @staticmethod
2102    def _annealing_cos(start, end, pct):
2103        """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
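        # cos_out falls from 2 to 0 as pct goes from 0 to 1, moving the result
        # smoothly from `start` to `end`.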
2104        cos_out = math.cos(math.pi * pct) + 1
2105        return end + (start - end) / 2.0 * cos_out
2106
2107    @staticmethod
2108    def _annealing_linear(start, end, pct):
2109        """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
2110        return (end - start) * pct + start
2111
2112    def get_lr(self):
2113        """Compute the learning rate of each parameter group."""
2114        _warn_get_lr_called_within_step(self)
2115
2116        lrs = []
2117        step_num = self.last_epoch
2118
2119        if step_num > self.total_steps:
2120            raise ValueError(
2121                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"  # noqa: UP032
2122            )
2123
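        # Find the phase containing `step_num`, then interpolate lr (and,
        # optionally, momentum) between that phase's start and end values.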
2124        for group in self.optimizer.param_groups:
2125            start_step = 0.0
2126            for i, phase in enumerate(self._schedule_phases):
2127                end_step = phase["end_step"]
2128                if step_num <= end_step or i == len(self._schedule_phases) - 1:
2129                    pct = (step_num - start_step) / (end_step - start_step)
2130                    computed_lr = self._anneal_func(
2131                        group[phase["start_lr"]], group[phase["end_lr"]], pct
2132                    )
2133                    if self.cycle_momentum:
2134                        computed_momentum = self._anneal_func(
2135                            group[phase["start_momentum"]],
2136                            group[phase["end_momentum"]],
2137                            pct,
2138                        )
2139                    break
2140                start_step = phase["end_step"]
2141
2142            lrs.append(computed_lr)  # type: ignore[possibly-undefined]
2143            if self.cycle_momentum:
2144                if self.use_beta1:
2145                    group["betas"] = (computed_momentum, *group["betas"][1:])  # type: ignore[possibly-undefined]
2146                else:
2147                    group[
2148                        "momentum"
2149                    ] = computed_momentum  # type: ignore[possibly-undefined]
2150
2151        return lrs
2152