# mypy: allow-untyped-defs
from typing import Any, Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter

from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module


__all__ = [
    "BatchNorm1d",
    "LazyBatchNorm1d",
    "BatchNorm2d",
    "LazyBatchNorm2d",
    "BatchNorm3d",
    "LazyBatchNorm3d",
    "SyncBatchNorm",
]


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(num_features, **factory_kwargs)
            )
            self.register_buffer(
                "running_var", torch.ones(num_features, **factory_kwargs)
            )
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
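            # num_batches_tracked is an integer counter, so only the device from
            # factory_kwargs is forwarded below; its dtype is always torch.long.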
            self.register_buffer(
                "num_batches_tracked",
                torch.tensor(
                    0,
                    dtype=torch.long,
                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
                ),
            )
            self.num_batches_tracked: Optional[Tensor]
        else:
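            # Register the attributes as None so they still exist (and are excluded
            # from the state dict) when running statistics are not tracked.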
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime
            # depending on whether self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
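                # Reuse the module's own counter when it holds real data; a counter on
                # the meta device has no storage, so fall back to a fresh zero instead.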
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    and self.num_batches_tracked.device != torch.device("meta")
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
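        # When an update occurs, each running buffer moves toward the batch statistic:
        #   running_stat = (1 - exponential_average_factor) * running_stat
        #                  + exponential_average_factor * batch_stat
        # with the unbiased batch variance stored in running_var.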
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):
    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(
        self,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
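        # Restore the requested flags and register uninitialized placeholders; they are
        # materialized with the inferred num_features on the first forward call.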
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )

    def reset_parameters(self) -> None:
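        # num_features == 0 means no input has been seen yet; defer the reset until
        # initialize_parameters() has materialized the parameters and buffers.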
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
                self.running_var.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
            self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
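        >>> # momentum=None switches the running statistics to a cumulative
        >>> # (simple) average instead of an exponential moving average
        >>> m = nn.BatchNorm1d(100, momentum=None)
        >>> output = m(input)
        >>> # in eval mode the tracked running estimates are used for normalization
        >>> output = m.eval()(input)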
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with an additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with an additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with an additional channel dimension, as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process group. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0. The standard-deviation is calculated via the
    biased estimator, equivalent to ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert a
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping the
    network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world.

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in [
                "cuda",
                torch._C._get_privateuse1_backend_name(),
            ]:
                raise ValueError(
                    "SyncBatchNorm expected input tensor to be on GPU or "
                    f"{torch._C._get_privateuse1_backend_name()}"
                )

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
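        # Recurse into child modules so nested BatchNorm*D layers are converted as well.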
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output