# mypy: allow-untyped-defs
from typing import Any, Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter

from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module


__all__ = [
    "BatchNorm1d",
    "LazyBatchNorm1d",
    "BatchNorm2d",
    "LazyBatchNorm2d",
    "BatchNorm3d",
    "LazyBatchNorm3d",
    "SyncBatchNorm",
]


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(num_features, **factory_kwargs)
            )
            self.register_buffer(
                "running_var", torch.ones(num_features, **factory_kwargs)
            )
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer(
                "num_batches_tracked",
                torch.tensor(
                    0,
                    dtype=torch.long,
                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
                ),
            )
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime
            # depending on whether self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    and self.num_batches_tracked.device != torch.device("meta")
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
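        # Illustrative note: F.batch_norm updates the passed buffers as
        #   running_stat = (1 - exponential_average_factor) * running_stat
        #                  + exponential_average_factor * batch_stat,
        # so the default momentum of 0.1 keeps 90% of the previous estimate, while
        # the cumulative branch above (factor = 1 / num_batches_tracked) weights
        # every batch seen so far equally.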
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):
    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(
        self,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
                self.running_var.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
            self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
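
    The example below is only an illustrative usage sketch; the feature dimension
    (100 here) is inferred from ``input.size(1)`` on the first forward call.

    Examples::

        >>> m = nn.LazyBatchNorm1d()
        >>> input = torch.randn(20, 100)
        >>> output = m(input)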
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
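
    The example below is only an illustrative usage sketch; the number of channels
    (100 here) is inferred from ``input.size(1)`` on the first forward call.

    Examples::

        >>> m = nn.LazyBatchNorm2d()
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)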
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
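
    The example below is only an illustrative usage sketch; the number of channels
    (100 here) is inferred from ``input.size(1)`` on the first forward call.

    Examples::

        >>> m = nn.LazyBatchNorm3d()
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)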
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements
    of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>     sync_bn_network,
        >>>     device_ids=[args.local_rank],
        >>>     output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in [
                "cuda",
                torch._C._get_privateuse1_backend_name(),
            ]:
                raise ValueError(
                    "SyncBatchNorm expected input tensor to be on GPU or "
                    f"{torch._C._get_privateuse1_backend_name()}"
                )

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>     torch.nn.Linear(20, 100),
            >>>     torch.nn.BatchNorm1d(100),
            >>> ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output