# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging."""
import itertools
import math
import warnings
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

from .optimizer import Optimizer


__all__ = [
    "AveragedModel",
    "update_bn",
    "SWALR",
    "get_ema_multi_avg_fn",
    "get_swa_multi_avg_fn",
    "get_ema_avg_fn",
    "get_swa_avg_fn",
]

from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype


PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]


def get_ema_multi_avg_fn(decay=0.999):
    """Get the function applying exponential moving average (EMA) across multiple params."""

    @torch.no_grad()
    def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
            ema_param_list[0]
        ):
            torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
        else:
            for p_ema, p_model in zip(ema_param_list, current_param_list):
                p_ema.copy_(p_ema * decay + p_model * (1 - decay))

    return ema_update


def get_swa_multi_avg_fn():
    """Get the function applying stochastic weight average (SWA) across multiple params."""

    @torch.no_grad()
    def swa_update(
        averaged_param_list: PARAM_LIST,
        current_param_list: PARAM_LIST,
        num_averaged: Union[Tensor, int],
    ):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
            averaged_param_list[0]
        ):
            torch._foreach_lerp_(
                averaged_param_list, current_param_list, 1 / (num_averaged + 1)
            )
        else:
            diffs = torch._foreach_sub(current_param_list, averaged_param_list)
            if isinstance(num_averaged, Tensor):
                torch._foreach_addcdiv_(
                    averaged_param_list,
                    diffs,
                    [num_averaged + 1] * len(averaged_param_list),
                )
            else:
                torch._foreach_add_(
                    averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
                )

    return swa_update


def get_ema_avg_fn(decay=0.999):
    """Get the function applying exponential moving average (EMA) across a single param."""

    @torch.no_grad()
    def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
        return decay * ema_param + (1 - decay) * current_param

    return ema_update


def get_swa_avg_fn():
    """Get the function applying stochastic weight average (SWA) across a single param."""

    @torch.no_grad()
    def swa_update(
        averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
    ):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    return swa_update


class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>     for input, target in loader:
        >>>         optimizer.zero_grad()
        >>>         loss_fn(model(input), target).backward()
        >>>         optimizer.step()
        >>>     if i > swa_start:
        >>>         swa_model.update_parameters(model)
        >>>         swa_scheduler.step()
        >>>     else:
        >>>         scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute an
    equally weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>     multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)

    .. note::
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        utility or by setting :attr:`use_buffers` to `True`. The first approach updates the
        statistics in a post-training step by passing data through the model. The
        second does it during the parameter update phase by averaging all buffers.
        Empirical evidence has shown that updating the statistics in normalization
        layers increases accuracy, but you may wish to empirically test which
        approach yields the best results in your problem.

    .. note::
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the function `avg_fn` is used
        to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    n_averaged: Tensor

    def __init__(
        self,
        model: Module,
        device: Optional[Union[int, torch.device]] = None,
        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
        multi_avg_fn: Optional[
            Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
        ] = None,
        use_buffers=False,
    ):  # noqa: D107
        super().__init__()
        assert (
            avg_fn is None or multi_avg_fn is None
        ), "Only one of avg_fn and multi_avg_fn should be provided"
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Forward pass."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: Module):
        """Update model parameters."""
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
        )
        self_param_detached: List[Optional[Tensor]] = []
        model_param_detached: List[Optional[Tensor]] = []
        for p_averaged, p_model in zip(self_param, model_param):
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if self.n_averaged == 0:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                grouped_tensors = _group_tensors_by_device_and_dtype(
                    [self_param_detached, model_param_detached]
                )
                for (device, _), (
                    [self_params, model_params],
                    _,
                ) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)  # type: ignore[arg-type]
                        )
                    elif (
                        device is not None
                        and device.type in _get_foreach_kernels_supported_devices()
                    ):
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)
                        )
                    else:
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                for p_averaged, p_model in zip(  # type: ignore[assignment]
                    self_param_detached, model_param_detached
                ):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(
                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                    )

        if not self.use_buffers:
            # If not applying running averages to the buffers, keep the buffers
            # in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1


@torch.no_grad()
def update_bn(
    loader: Iterable[Any],
    model: Module,
    device: Optional[Union[int, torch.device]] = None,
):
    r"""Update BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta.keys():
        module.momentum = None

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)


class SWALR(LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.

    This learning rate scheduler is meant to be used with the Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).

    Args:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)

    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in the training
    as in the example below.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, optimizer, model = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>     lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>     anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>     for input, target in loader:
        >>>         optimizer.zero_grad()
        >>>         loss_fn(model(input), target).backward()
        >>>         optimizer.step()
        >>>     if i > swa_start:
        >>>         swa_scheduler.step()
        >>>     else:
        >>>         scheduler.step()

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """

    def __init__(
        self,
        optimizer: Optimizer,
        swa_lr: float,
        anneal_epochs=10,
        anneal_strategy: Literal["cos", "linear"] = "cos",
        last_epoch=-1,
    ):  # noqa: D107
        swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group["swa_lr"] = swa_lr
        if anneal_strategy not in ["cos", "linear"]:
            raise ValueError(
                "anneal_strategy must be one of 'cos' or 'linear', "
                f"instead got {anneal_strategy}"
            )
        elif anneal_strategy == "cos":
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == "linear":
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
            raise ValueError(
                f"anneal_epochs must be equal to or greater than 0, got {anneal_epochs}"
            )
        self.anneal_epochs = anneal_epochs
        super().__init__(optimizer, last_epoch)

    @staticmethod
    def _linear_anneal(t):
        return t

    @staticmethod
    def _cosine_anneal(t):
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
        """Get learning rate."""
        # `_get_lr_called_within_step` is only available within `_enable_get_lr_call`,
        # so we ignore the type error here. See `LRScheduler.step()` for more details.
        if not self._get_lr_called_within_step:  # type: ignore[attr-defined]
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        # Set in `LRScheduler._initial_step()`
        step = self._step_count - 1  # type: ignore[attr-defined]
        if self.anneal_epochs == 0:
            step = max(1, step)
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [
            self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
            for group in self.optimizer.param_groups
        ]
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [
            group["swa_lr"] * alpha + lr * (1 - alpha)
            for group, lr in zip(self.optimizer.param_groups, prev_lrs)
        ]
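

# Minimal usage sketch (illustrative only, not part of this module's public
# API): it wires together ``AveragedModel`` with ``get_ema_multi_avg_fn`` and
# ``SWALR`` on a toy ``nn.Linear`` model with random data. The model, data,
# and hyperparameters below are arbitrary placeholders for the example, and
# the block runs only when the file is executed directly.
if __name__ == "__main__":
    import torch.nn.functional as F
    from torch import nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # EMA of parameters and buffers with decay 0.999, updated in place via the
    # multi-tensor (foreach) averaging function.
    ema_model = AveragedModel(
        model, multi_avg_fn=get_ema_multi_avg_fn(0.999), use_buffers=True
    )
    # Anneal every param group's learning rate towards swa_lr over 5 steps.
    swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5)

    for _ in range(10):
        inp = torch.randn(32, 10)
        target = torch.randint(0, 2, (32,))
        optimizer.zero_grad()
        F.cross_entropy(model(inp), target).backward()
        optimizer.step()
        ema_model.update_parameters(model)  # fold the new weights into the EMA
        swa_scheduler.step()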