# mypy: allow-untyped-defs

import warnings

import torch.nn.functional as F
from torch import Tensor

from .batchnorm import _LazyNormBase, _NormBase


__all__ = [
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "LazyInstanceNorm1d",
    "LazyInstanceNorm2d",
    "LazyInstanceNorm3d",
]


class _InstanceNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _get_no_batch_dim(self):
        raise NotImplementedError

    def _handle_no_batch_input(self, input):
        # Temporarily add a batch dimension of size 1, normalize, then drop it.
        return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)

    def _apply_instance_norm(self, input):
        return F.instance_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            # use_input_stats: use per-instance statistics while training or
            # whenever running statistics are not tracked; otherwise normalize
            # with the tracked running estimates.
            self.training or not self.track_running_stats,
            self.momentum if self.momentum is not None else 0.0,
            self.eps,
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ("running_mean", "running_var"):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    "Unexpected running stats buffer(s) {names} for {klass} "
                    "with track_running_stats=False. If state_dict is a "
                    "checkpoint saved before 0.4.0, this may be expected "
                    "because {klass} does not track running stats by default "
                    "since 0.4.0. Please remove these keys from state_dict. If "
                    "the running stats are actually needed, instead set "
                    "track_running_stats=True in {klass} to enable them. See "
                    "the documentation of {klass} for details.".format(
                        names=" and ".join(f'"{k}"' for k in running_stats_keys),
                        klass=self.__class__.__name__,
                    )
                )
                for key in running_stats_keys:
                    state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        feature_dim = input.dim() - self._get_no_batch_dim()
        if input.size(feature_dim) != self.num_features:
            if self.affine:
                raise ValueError(
                    f"expected input's size at dim={feature_dim} to match num_features"
                    f" ({self.num_features}), but got: {input.size(feature_dim)}."
                )
            else:
                warnings.warn(
                    f"input's size at dim={feature_dim} does not match num_features. "
                    "You can silence this warning by not passing in num_features, "
                    "which is not used because affine=False"
                )

        if input.dim() == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

        return self._apply_instance_norm(input)


class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 2D (unbatched) or 3D (batched) input as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
    The standard deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

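    A minimal sketch of the contrast (``track_running_stats=True`` is an
    explicit, non-default choice here)::

        >>> m = nn.InstanceNorm1d(100, track_running_stats=True)
        >>> _ = m(torch.randn(20, 100, 40))  # training mode: running stats are updated
        >>> m = m.eval()  # eval mode: the running estimates are used instead
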
    .. note::
        This :attr:`momentum` argument is different from the one used in
        optimizer classes and the conventional notion of momentum.
        Mathematically, the update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.
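        For example, with the default :attr:`momentum` of 0.1, a running mean
        of 1.0 and a newly observed batch mean of 2.0 update to
        :math:`0.9 \times 1.0 + 0.1 \times 2.0 = 1.1`.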

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied to
        each channel of channeled data like multidimensional time series, while
        :class:`LayerNorm` is usually applied over an entire sample and is
        often used in NLP tasks. Additionally, :class:`LayerNorm` applies an
        elementwise affine transform, while :class:`InstanceNorm1d` usually
        doesn't apply an affine transform.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
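        >>> # A minimal sketch: unbatched (C, L) input is also supported.
        >>> unbatched = torch.randn(100, 40)
        >>> output = m(unbatched)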
    """

    def _get_no_batch_dim(self):
        return 2

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm1d` is inferred from
    ``input.size(1)``. The attributes that will be lazily initialized are
    ``weight``, ``bias``, ``running_mean`` and ``running_var``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`(C, L)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``True``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
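
    A minimal usage sketch (``num_features`` is inferred on the first forward
    pass, after which the module becomes an :class:`InstanceNorm1d`)::

        >>> m = nn.LazyInstanceNorm1d()
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)  # num_features is now inferred as 100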
    """

    cls_to_become = InstanceNorm1d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 2

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in
        optimizer classes and the conventional notion of momentum.
        Mathematically, the update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied to
        each channel of channeled data like RGB images, while
        :class:`LayerNorm` is usually applied over an entire sample and is
        often used in NLP tasks. Additionally, :class:`LayerNorm` applies an
        elementwise affine transform, while :class:`InstanceNorm2d` usually
        doesn't apply an affine transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
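        >>> # A minimal sketch: unbatched (C, H, W) input is also supported.
        >>> unbatched = torch.randn(100, 35, 45)
        >>> output = m(unbatched)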
    """

    def _get_no_batch_dim(self):
        return 3

    def _check_input_dim(self, input):
        if input.dim() not in (3, 4):
            raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")


class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm2d` is inferred from
    ``input.size(1)``. The attributes that will be lazily initialized are
    ``weight``, ``bias``, ``running_mean`` and ``running_var``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``True``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
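
    A minimal usage sketch (``num_features`` is inferred on the first forward
    pass, after which the module becomes an :class:`InstanceNorm2d`)::

        >>> m = nn.LazyInstanceNorm2d()
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)  # num_features is now inferred as 100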
    """

    cls_to_become = InstanceNorm2d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 3

    def _check_input_dim(self, input):
        if input.dim() not in (3, 4):
            raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")


class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in
        optimizer classes and the conventional notion of momentum.
        Mathematically, the update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied to
        each channel of channeled data like 3D models with RGB color channels,
        while :class:`LayerNorm` is usually applied over an entire sample and is
        often used in NLP tasks. Additionally, :class:`LayerNorm` applies an
        elementwise affine transform, while :class:`InstanceNorm3d` usually
        doesn't apply an affine transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
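        >>> # A minimal sketch: unbatched (C, D, H, W) input is also supported.
        >>> unbatched = torch.randn(100, 35, 45, 10)
        >>> output = m(unbatched)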
    """

    def _get_no_batch_dim(self):
        return 4

    def _check_input_dim(self, input):
        if input.dim() not in (4, 5):
            raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")


class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm3d` is inferred from
    ``input.size(1)``. The attributes that will be lazily initialized are
    ``weight``, ``bias``, ``running_mean`` and ``running_var``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value; when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``True``.
        track_running_stats: a boolean value; when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
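
    A minimal usage sketch (``num_features`` is inferred on the first forward
    pass, after which the module becomes an :class:`InstanceNorm3d`)::

        >>> m = nn.LazyInstanceNorm3d()
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)  # num_features is now inferred as 100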
    """

    cls_to_become = InstanceNorm3d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 4

    def _check_input_dim(self, input):
        if input.dim() not in (4, 5):
            raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")