xref: /aosp_15_r20/external/pytorch/torch/nn/modules/normalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import numbers
3from typing import List, Optional, Tuple, Union
4
5import torch
6from torch import Size, Tensor
7from torch.nn import functional as F, init
8from torch.nn.parameter import Parameter
9
10from ._functions import CrossMapLRN2d as _cross_map_lrn2d
11from .module import Module
12
13
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"]
15
16
class LocalResponseNorm(Module):
    r"""Normalizes every element by the energy of its neighbouring channels.

    The channel axis is the second dimension of the input; for each channel
    ``c`` a window of :attr:`size` neighbouring channels enters the sum below.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: number of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """

    __constants__ = ["size", "alpha", "beta", "k"]
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(
        self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0
    ) -> None:
        super().__init__()
        # Hyper-parameters live as plain attributes, declared in ``__constants__``.
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        """Apply local response normalization across the channel dimension."""
        return F.local_response_norm(
            input, self.size, alpha=self.alpha, beta=self.beta, k=self.k
        )

    def extra_repr(self) -> str:
        """Summarize the hyper-parameters for ``repr(module)``."""
        template = "{size}, alpha={alpha}, beta={beta}, k={k}"
        return template.format(**self.__dict__)
67
68
class CrossMapLRN2d(Module):
    r"""Cross-channel local response normalization module.

    Stores the hyper-parameters and, on every call, forwards them together with
    the input to the ``CrossMapLRN2d`` autograd function imported from the
    sibling ``_functions`` module; the computation itself happens there.

    Args:
        size: number of neighbouring channels in the normalization window
        alpha: multiplicative factor. Default: 1e-4
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    .. note::
        NOTE(review): the ``2d`` suffix suggests a 4-D ``(N, C, H, W)`` input;
        confirm against the autograd function. :class:`LocalResponseNorm`
        accepts general ``(N, C, *)`` inputs.
    """

    size: int
    alpha: float
    beta: float
    k: float

    def __init__(
        self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1
    ) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        """Run the cross-map LRN autograd function on ``input``."""
        # Positional order expected by the autograd function: size, alpha, beta, k.
        hyper_params = (self.size, self.alpha, self.beta, self.k)
        return _cross_map_lrn2d.apply(input, *hyper_params)

    def extra_repr(self) -> str:
        """Summarize the hyper-parameters for ``repr(module)``."""
        template = "{size}, alpha={alpha}, beta={beta}, k={k}"
        return template.format(**self.__dict__)
89
90
# Accepted forms of ``normalized_shape`` for LayerNorm and RMSNorm; both turn a
# bare ``int`` into a 1-tuple and then store ``tuple(normalized_shape)``.
_shape_t = Union[int, List[int], Size]
92
93
class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``
            (otherwise ``None``). The values are initialized to 1.
        bias:   the learnable bias of the module of shape
                :math:`\text{normalized\_shape}` when both :attr:`elementwise_affine` and
                :attr:`bias` are set to ``True`` (otherwise ``None``).
                The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # A bare int means "normalize over the last dimension only".
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            if bias:
                self.bias = Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                # Register as None so ``self.bias`` always exists for forward().
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        # Storage above is uninitialized (torch.empty); fill it here.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset ``weight`` to ones and ``bias`` (if present) to zeros."""
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` over its trailing ``normalized_shape`` dimensions."""
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        """Summarize the configuration for ``repr(module)``."""
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
225        )
226
227
class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_groups` must be positive and
    :attr:`num_channels` must be divisible by :attr:`num_groups`. The mean and
    standard-deviation are calculated separately over each group. :math:`\gamma` and
    :math:`\beta` are learnable per-channel affine transform parameter vectors of size
    :attr:`num_channels` if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into; must be positive
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Raises:
        ValueError: if :attr:`num_groups` is not positive, or if :attr:`num_channels`
            is not divisible by :attr:`num_groups`.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """

    __constants__ = ["num_groups", "num_channels", "eps", "affine"]
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Validate before the modulo below: ``num_groups == 0`` would otherwise
        # raise ZeroDivisionError, and a negative value can slip through the
        # divisibility check (e.g. ``6 % -3 == 0``).
        if num_groups <= 0:
            raise ValueError(
                f"num_groups must be a positive integer, got {num_groups}"
            )
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            # Register as None so ``self.weight``/``self.bias`` always exist.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        # Storage above is uninitialized (torch.empty); fill it here.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters (if any) to ones and zeros."""
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` per group of channels."""
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        """Summarize the configuration for ``repr(module)``."""
        return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format(
            **self.__dict__
        )
318        )
319
320
321class RMSNorm(Module):
322    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
323
324    This layer implements the operation as described in
325    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__
326
327    .. math::
328        y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma
329
330    The root mean squared norm is taken over the last ``D`` dimensions, where ``D``
331    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
332    is ``(3, 5)`` (a 2-dimensional shape), the rms norm is computed over
333    the last 2 dimensions of the input.
334
335    Args:
336        normalized_shape (int or list or torch.Size): input shape from an expected input
337            of size
338
339            .. math::
340                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
341                    \times \ldots \times \text{normalized\_shape}[-1]]
342
343            If a single integer is used, it is treated as a singleton list, and this module will
344            normalize over the last dimension which is expected to be of that specific size.
345        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
346        elementwise_affine: a boolean value that when set to ``True``, this module
347            has learnable per-element affine parameters initialized to ones (for weights)
348            and zeros (for biases). Default: ``True``.
349
350    Shape:
351        - Input: :math:`(N, *)`
352        - Output: :math:`(N, *)` (same shape as input)
353
354    Examples::
355
356        >>> rms_norm = nn.RMSNorm([2, 3])
357        >>> input = torch.randn(2, 2, 3)
358        >>> rms_norm(input)
359
360    """
361    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
362    normalized_shape: Tuple[int, ...]
363    eps: Optional[float]
364    elementwise_affine: bool
365
366    def __init__(
367        self,
368        normalized_shape: _shape_t,
369        eps: Optional[float] = None,
370        elementwise_affine: bool = True,
371        device=None,
372        dtype=None,
373    ) -> None:
374        factory_kwargs = {"device": device, "dtype": dtype}
375        super().__init__()
376        if isinstance(normalized_shape, numbers.Integral):
377            # mypy error: incompatible types in assignment
378            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
379        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
380        self.eps = eps
381        self.elementwise_affine = elementwise_affine
382        if self.elementwise_affine:
383            self.weight = Parameter(
384                torch.empty(self.normalized_shape, **factory_kwargs)
385            )
386        else:
387            self.register_parameter("weight", None)
388        self.reset_parameters()
389
390    def reset_parameters(self) -> None:
391        """
392        Resets parameters based on their initialization used in __init__.
393        """
394        if self.elementwise_affine:
395            init.ones_(self.weight)
396
397    def forward(self, x: torch.Tensor) -> torch.Tensor:
398        """
399        Runs forward pass.
400        """
401        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
402
403    def extra_repr(self) -> str:
404        """
405        Extra information about the module.
406        """
407        return (
408            "{normalized_shape}, eps={eps}, "
409            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
410        )
411
412
413# TODO: ContrastiveNorm2d
414# TODO: DivisiveNorm2d
415# TODO: SubtractiveNorm2d
416