# mypy: allow-untyped-defs
import numbers
from typing import List, Optional, Tuple, Union

import torch
from torch import Size, Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter

from ._functions import CrossMapLRN2d as _cross_map_lrn2d
from .module import Module


__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"]


class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal.

    The input signal is composed of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """

    __constants__ = ["size", "alpha", "beta", "k"]
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(
        self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0
    ) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self):
        return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)


class CrossMapLRN2d(Module):
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(
        self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1
    ) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self) -> str:
        return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)
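

# A minimal sanity-check sketch, not part of the torch.nn API and never called by
# the library: it applies the LRN formula from the LocalResponseNorm docstring by
# hand (for an odd ``size``) and compares the result with the module output.
# CrossMapLRN2d above is the legacy autograd-function counterpart that expects 4D
# (N, C, H, W) input and is intended to normalize across channels the same way.
# The helper name and the tolerance below are illustrative assumptions.
def _local_response_norm_reference_sketch() -> None:
    size, alpha, beta, k = 3, 1e-4, 0.75, 1.0
    lrn = LocalResponseNorm(size, alpha, beta, k)
    x = torch.randn(2, 5, 4, 4)
    squared = x.pow(2)
    num_channels = x.size(1)
    summed = torch.empty_like(x)
    for c in range(num_channels):
        # sum a_{c'}^2 over the neighbouring-channel window, clipped at the edges
        lo = max(0, c - size // 2)
        hi = min(num_channels - 1, c + size // 2)
        summed[:, c] = squared[:, lo : hi + 1].sum(dim=1)
    # b_c = a_c * (k + alpha / n * sum)^(-beta)
    manual = x * (k + alpha / size * summed).pow(-beta)
    assert torch.allclose(lrn(x), manual, atol=1e-6)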


_shape_t = Union[int, List[int], Size]


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias: the learnable bias of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            if bias:
                self.bias = Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
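

# A minimal sanity-check sketch, not part of the torch.nn API and never called by
# the library: it recomputes LayerNorm by hand with the biased variance estimator
# described in the docstring above, assuming the default affine parameters (weight
# of ones, bias of zeros) produced by reset_parameters(). The helper name and the
# tolerance are illustrative assumptions.
def _layer_norm_reference_sketch() -> None:
    x = torch.randn(4, 3, 5)
    layer_norm = LayerNorm([3, 5])
    # statistics over the last D = 2 dimensions, given normalized_shape == (3, 5)
    mean = x.mean(dim=(-2, -1), keepdim=True)
    var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)  # biased estimator
    manual = (x - mean) / torch.sqrt(var + layer_norm.eps)
    assert torch.allclose(layer_norm(x), manual, atol=1e-6)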


class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
    :attr:`num_groups`. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent to LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """

    __constants__ = ["num_groups", "num_channels", "eps", "affine"]
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format(
            **self.__dict__
        )
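

# A minimal sanity-check sketch, not part of the torch.nn API and never called by
# the library: it recomputes GroupNorm by hand, normalizing each group of
# num_channels / num_groups channels with the biased variance estimator and assuming
# the default affine parameters (weight of ones, bias of zeros). The helper name and
# the tolerance are illustrative assumptions.
def _group_norm_reference_sketch() -> None:
    N, C, H, W = 2, 6, 4, 4
    num_groups = 3
    x = torch.randn(N, C, H, W)
    group_norm = GroupNorm(num_groups, C)
    # statistics are shared by all channels and spatial positions within a group,
    # computed separately for every sample in the batch
    grouped = x.view(N, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)
    manual = ((grouped - mean) / torch.sqrt(var + group_norm.eps)).view(N, C, H, W)
    assert torch.allclose(group_norm(x), manual, atol=1e-6)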


class RMSNorm(Module):
    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__

    .. math::
        y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma

    The root mean squared norm is taken over the last ``D`` dimensions, where ``D``
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the rms norm is computed over
    the last 2 dimensions of the input.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights).
            Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> rms_norm = nn.RMSNorm([2, 3])
        >>> input = torch.randn(2, 2, 3)
        >>> rms_norm(input)

    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs forward pass.
        """
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
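

# A minimal sanity-check sketch, not part of the torch.nn API and never called by
# the library: it recomputes RMSNorm by hand as x / sqrt(mean(x**2) + eps) over the
# normalized dimensions, assuming the default weight of ones and the documented
# default eps of torch.finfo(x.dtype).eps. The helper name and the tolerance are
# illustrative assumptions.
def _rms_norm_reference_sketch() -> None:
    x = torch.randn(2, 2, 3)
    rms_norm = RMSNorm([2, 3])
    eps = torch.finfo(x.dtype).eps if rms_norm.eps is None else rms_norm.eps
    # mean of squares over the last D = 2 dimensions, given normalized_shape == (2, 3)
    mean_square = x.pow(2).mean(dim=(-2, -1), keepdim=True)
    manual = x * torch.rsqrt(mean_square + eps)
    assert torch.allclose(rms_norm(x), manual, atol=1e-6)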


# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d