# mypy: allow-untyped-defs
r"""Implementation of Stochastic Weight Averaging."""
import itertools
import math
import warnings
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

from .optimizer import Optimizer


__all__ = [
    "AveragedModel",
    "update_bn",
    "SWALR",
    "get_ema_multi_avg_fn",
    "get_swa_multi_avg_fn",
    "get_ema_avg_fn",
    "get_swa_avg_fn",
]

from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype


PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]


def get_ema_multi_avg_fn(decay=0.999):
    """Get the function applying exponential moving average (EMA) across multiple params."""

    @torch.no_grad()
    def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
            ema_param_list[0]
        ):
            torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
        else:
            for p_ema, p_model in zip(ema_param_list, current_param_list):
                p_ema.copy_(p_ema * decay + p_model * (1 - decay))

    return ema_update
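
# Each call blends the averaged copy toward the live parameters:
#     ema = decay * ema + (1 - decay) * param  (i.e. lerp(ema, param, 1 - decay)),
# so decay=0.999 averages over a horizon of roughly 1 / (1 - decay) = 1000 steps.
# Usage sketch (hypothetical `model` defined by the caller):
#
#     ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay=0.999))
#     ema_model.update_parameters(model)  # typically called after each optimizer step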


def get_swa_multi_avg_fn():
    """Get the function applying stochastic weight average (SWA) across multiple params."""

    @torch.no_grad()
    def swa_update(
        averaged_param_list: PARAM_LIST,
        current_param_list: PARAM_LIST,
        num_averaged: Union[Tensor, int],
    ):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
            averaged_param_list[0]
        ):
            torch._foreach_lerp_(
                averaged_param_list, current_param_list, 1 / (num_averaged + 1)
            )
        else:
            diffs = torch._foreach_sub(current_param_list, averaged_param_list)
            if isinstance(num_averaged, Tensor):
                torch._foreach_addcdiv_(
                    averaged_param_list,
                    diffs,
                    [num_averaged + 1] * len(averaged_param_list),
                )
            else:
                torch._foreach_add_(
                    averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
                )

    return swa_update
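
# The SWA update above is the incremental running mean,
#     avg_{n+1} = avg_n + (param - avg_n) / (n + 1)  (i.e. lerp(avg, param, 1 / (n + 1))),
# where n is the number of models already averaged, so every model passed to
# `AveragedModel.update_parameters` ends up weighted equally.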


def get_ema_avg_fn(decay=0.999):
    """Get the function applying exponential moving average (EMA) across a single param."""

    @torch.no_grad()
    def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
        return decay * ema_param + (1 - decay) * current_param

    return ema_update


def get_swa_avg_fn():
    """Get the function applying stochastic weight average (SWA) across a single param."""

    @torch.no_grad()
    def swa_update(
        averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
    ):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    return swa_update
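
# The single-parameter variants above apply the same updates one tensor at a time
# and can be passed to `AveragedModel` via `avg_fn`; `get_swa_avg_fn` is also the
# per-tensor fallback used by `AveragedModel.update_parameters` when the grouped
# `_foreach_` kernels are unavailable for a device. Sketch (hypothetical `model`):
#
#     ema_model = AveragedModel(model, avg_fn=get_ema_avg_fn(decay=0.999))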


class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    The AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>                                     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_model.update_parameters(model)
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute an
    equally-weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)

    .. note::
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        utility or by setting :attr:`use_buffers` to ``True``. The first approach updates the
        statistics in a post-training step by passing data through the model. The
        second does it during the parameter update phase by averaging all buffers.
        Empirical evidence has shown that updating the statistics in normalization
        layers increases accuracy, but you may wish to empirically test which
        approach yields the best results for your problem.

    .. note::
        :attr:`avg_fn` and :attr:`multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the averaging function is used
        to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    n_averaged: Tensor

    def __init__(
        self,
        model: Module,
        device: Optional[Union[int, torch.device]] = None,
        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
        multi_avg_fn: Optional[
            Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
        ] = None,
        use_buffers=False,
    ):  # noqa: D107
        super().__init__()
        assert (
            avg_fn is None or multi_avg_fn is None
        ), "Only one of avg_fn and multi_avg_fn should be provided"
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Forward pass."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: Module):
        """Update model parameters."""
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
        )
        self_param_detached: List[Optional[Tensor]] = []
        model_param_detached: List[Optional[Tensor]] = []
        for p_averaged, p_model in zip(self_param, model_param):
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if self.n_averaged == 0:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                # Group parameters by device and dtype so a user-supplied
                # `multi_avg_fn` or the batched `_foreach_` SWA update can be
                # applied to each group at once.
                grouped_tensors = _group_tensors_by_device_and_dtype(
                    [self_param_detached, model_param_detached]
                )
                for (device, _), (
                    [self_params, model_params],
                    _,
                ) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)  # type: ignore[arg-type]
                        )
                    elif (
                        device is not None
                        and device.type in _get_foreach_kernels_supported_devices()
                    ):
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)
                        )
                    else:
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                for p_averaged, p_model in zip(  # type: ignore[assignment]
                    self_param_detached, model_param_detached
                ):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(
                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                    )

        if not self.use_buffers:
            # If not applying running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1
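
# A minimal EMA training-loop sketch (hypothetical `model`, `optimizer`, `loader`
# and `loss_fn` defined by the caller); unlike the epoch-level SWA example in the
# class docstring, EMA copies are usually updated after every optimizer step:
#
#     ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))
#     for inp, target in loader:
#         optimizer.zero_grad()
#         loss_fn(model(inp), target).backward()
#         optimizer.step()
#         ema_model.update_parameters(model)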


@torch.no_grad()
def update_bn(
    loader: Iterable[Any],
    model: Module,
    device: Optional[Union[int, torch.device]] = None,
):
    r"""Update BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    # Setting momentum to None makes the _BatchNorm modules accumulate a
    # cumulative (equal-weight) average over the pass below; the original
    # momenta are restored afterwards.
    for module in momenta.keys():
        module.momentum = None

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)


class SWALR(LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.

    This learning rate scheduler is meant to be used with the Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).

    Args:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together, or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)

    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in training,
    as in the example below.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>        lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>        anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """

    def __init__(
        self,
        optimizer: Optimizer,
        swa_lr: float,
        anneal_epochs=10,
        anneal_strategy: Literal["cos", "linear"] = "cos",
        last_epoch=-1,
    ):  # noqa: D107
        swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group["swa_lr"] = swa_lr
        if anneal_strategy not in ["cos", "linear"]:
            raise ValueError(
                "anneal_strategy must be one of 'cos' or 'linear', "
                f"instead got {anneal_strategy}"
            )
        elif anneal_strategy == "cos":
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == "linear":
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
            raise ValueError(
                f"anneal_epochs must be equal to or greater than 0, got {anneal_epochs}"
            )
        self.anneal_epochs = anneal_epochs
        super().__init__(optimizer, last_epoch)
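
    # `swa_lr` may also be a list with one target value per parameter group; a
    # sketch with hypothetical `backbone` and `head` modules:
    #
    #     optimizer = torch.optim.SGD(
    #         [{"params": backbone.parameters()}, {"params": head.parameters()}], lr=0.1
    #     )
    #     swa_scheduler = SWALR(optimizer, swa_lr=[0.01, 0.05])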

    @staticmethod
    def _linear_anneal(t):
        return t

    @staticmethod
    def _cosine_anneal(t):
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

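    # `get_lr` interpolates each group's lr toward its `swa_lr` target: with
    # t = step / anneal_epochs clipped to [0, 1] and alpha = anneal_func(t),
    #     lr_t = alpha * swa_lr + (1 - alpha) * lr_0,
    # where lr_0, the lr at the start of annealing, is recovered from the group's
    # current lr by inverting the previous step's interpolation (`_get_initial_lr`).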
    def get_lr(self):
        """Get learning rate."""
        # `_get_lr_called_within_step` is only available within `_enable_get_lr_call`,
        # so we ignore the type error here. See `LRScheduler.step()` for more details.
        if not self._get_lr_called_within_step:  # type: ignore[attr-defined]
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        # Set in `LRScheduler._initial_step()`
        step = self._step_count - 1  # type: ignore[attr-defined]
        if self.anneal_epochs == 0:
            step = max(1, step)
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [
            self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
            for group in self.optimizer.param_groups
        ]
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [
            group["swa_lr"] * alpha + lr * (1 - alpha)
            for group, lr in zip(self.optimizer.param_groups, prev_lrs)
        ]
468