# mypy: allow-untyped-defs

import warnings

import torch.nn.functional as F
from torch import Tensor

from .batchnorm import _LazyNormBase, _NormBase


__all__ = [
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "LazyInstanceNorm1d",
    "LazyInstanceNorm2d",
    "LazyInstanceNorm3d",
]


class _InstanceNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _get_no_batch_dim(self):
        raise NotImplementedError

    def _handle_no_batch_input(self, input):
        # Temporarily add a batch dimension of size 1 so unbatched inputs go
        # through the same code path as batched ones.
        return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)

    def _apply_instance_norm(self, input):
        # Instance statistics are used unless we are in eval mode with
        # track_running_stats=True, in which case the running estimates apply.
        return F.instance_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            self.momentum if self.momentum is not None else 0.0,
            self.eps,
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ("running_mean", "running_var"):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    "Unexpected running stats buffer(s) {names} for {klass} "
                    "with track_running_stats=False. If state_dict is a "
                    "checkpoint saved before 0.4.0, this may be expected "
                    "because {klass} does not track running stats by default "
                    "since 0.4.0. Please remove these keys from state_dict. If "
                    "the running stats are actually needed, instead set "
                    "track_running_stats=True in {klass} to enable them. See "
                    "the documentation of {klass} for details.".format(
                        names=" and ".join(f'"{k}"' for k in running_stats_keys),
                        klass=self.__class__.__name__,
                    )
                )
            for key in running_stats_keys:
                state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # The feature (channel) dimension is dim 0 for unbatched input and
        # dim 1 otherwise.
        feature_dim = input.dim() - self._get_no_batch_dim()
        if input.size(feature_dim) != self.num_features:
            if self.affine:
                raise ValueError(
                    f"expected input's size at dim={feature_dim} to match num_features"
                    f" ({self.num_features}), but got: {input.size(feature_dim)}."
                )
            else:
                warnings.warn(
                    f"input's size at dim={feature_dim} does not match num_features. "
                    "You can silence this warning by not passing in num_features, "
                    "which is not used because affine=False"
                )

        if input.dim() == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

        return self._apply_instance_norm(input)

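
# Editorial sketch (not part of the original module): a minimal illustration
# of the no-batch-dim handling in _InstanceNorm._handle_no_batch_input above.
# An unbatched (C, L) input is normalized by temporarily adding a batch
# dimension of size 1, so the result matches the batched computation exactly.
# The helper name below is ours.
def _demo_no_batch_dim_equivalence():
    import torch
    from torch import nn

    m = nn.InstanceNorm1d(4)
    x = torch.randn(4, 8)  # unbatched (C, L) input

    # The output on the unbatched input equals the output on the batched
    # view with the batch dimension squeezed back out.
    out_unbatched = m(x)
    out_batched = m(x.unsqueeze(0)).squeeze(0)
    assert torch.allclose(out_unbatched, out_batched)
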

class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 2D (unbatched) or 3D (batched) input as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied
        on each channel of channeled data like multidimensional time series, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm1d` usually does not apply an affine
        transform.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
    """

    def _get_no_batch_dim(self):
        return 2

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

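
# Editorial sketch (not part of the original module): a hedged illustration
# of the running-statistics update rule from the momentum note above,
# x_hat_new = (1 - momentum) * x_hat + momentum * x_t. For InstanceNorm1d
# with track_running_stats=True, the observed value x_t for running_mean
# works out to the per-channel mean over the batch and length dimensions.
# The helper name below is ours.
def _demo_momentum_update():
    import torch
    from torch import nn

    m = nn.InstanceNorm1d(3, momentum=0.1, track_running_stats=True)
    x = torch.randn(5, 3, 10)

    old_mean = m.running_mean.clone()  # initialized to zeros
    m(x)  # training mode: one forward pass updates the running stats

    batch_mean = x.mean(dim=(0, 2))  # per-channel mean over batch and length
    expected = (1 - 0.1) * old_mean + 0.1 * batch_mean
    assert torch.allclose(m.running_mean, expected, atol=1e-6)
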

class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm1d` is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`.

    See :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`(C, L)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
    """

    cls_to_become = InstanceNorm1d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 2

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied
        on each channel of channeled data like RGB images, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm2d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _get_no_batch_dim(self):
        return 3

    def _check_input_dim(self, input):
        if input.dim() not in (3, 4):
            raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")

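
# Editorial sketch (not part of the original module): a minimal illustration
# of the per-instance statistics described in the docstrings above. With
# affine=False and no running stats, every (sample, channel) slice of the
# output has (approximately) zero mean and unit variance, up to the eps term
# in the denominator. The helper name below is ours.
def _demo_per_instance_statistics():
    import torch
    from torch import nn

    m = nn.InstanceNorm2d(3)
    x = torch.randn(2, 3, 8, 8)
    out = m(x)

    # Statistics are computed over the spatial dims (H, W), separately for
    # every sample n and channel c; the biased estimator matches
    # torch.var(..., unbiased=False).
    means = out.mean(dim=(2, 3))
    variances = out.var(dim=(2, 3), unbiased=False)
    assert torch.allclose(means, torch.zeros_like(means), atol=1e-5)
    assert torch.allclose(variances, torch.ones_like(variances), atol=1e-3)
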

class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm2d` is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    See :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
    """

    cls_to_become = InstanceNorm2d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 3

    def _check_input_dim(self, input):
        if input.dim() not in (3, 4):
            raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")

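
# Editorial sketch (not part of the original module): a short illustration of
# the lazy-initialization flow documented above. The module starts
# uninitialized, infers num_features from input.size(1) on the first forward
# pass, and then becomes a regular InstanceNorm2d (see cls_to_become).
# The helper name below is ours.
def _demo_lazy_initialization():
    import torch
    from torch import nn

    m = nn.LazyInstanceNorm2d(affine=True)
    x = torch.randn(4, 16, 8, 8)
    m(x)  # first forward pass: num_features is inferred as input.size(1) == 16

    assert isinstance(m, nn.InstanceNorm2d)  # lazy module was swapped in place
    assert m.num_features == 16
    assert m.weight.shape == (16,)  # affine parameters now materialized
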

class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization.

    This operation applies Instance Normalization
    over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied
        on each channel of channeled data like 3D models with RGB color, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm3d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _get_no_batch_dim(self):
        return 4

    def _check_input_dim(self, input):
        if input.dim() not in (4, 5):
            raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")

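
# Editorial sketch (not part of the original module): a hedged illustration
# contrasting the two evaluation-time behaviors described above. By default
# instance statistics are used even in eval mode, whereas with
# track_running_stats=True the stored running estimates are used instead.
# The helper name below is ours.
def _demo_eval_mode_behavior():
    import torch
    from torch import nn

    x = torch.randn(2, 3, 4, 8, 8)

    default = nn.InstanceNorm3d(3).eval()
    tracked = nn.InstanceNorm3d(3, track_running_stats=True).eval()

    # The default module keeps normalizing with per-instance statistics.
    out_default = default(x)
    # The tracking module uses its (never updated) running_mean=0 and
    # running_var=1, so its eval output is just x scaled by 1/sqrt(1 + eps).
    out_tracked = tracked(x)
    assert not torch.allclose(out_default, out_tracked)
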

class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument.

    The ``num_features`` argument of :class:`InstanceNorm3d` is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    See :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
    """

    cls_to_become = InstanceNorm3d  # type: ignore[assignment]

    def _get_no_batch_dim(self):
        return 4

    def _check_input_dim(self, input):
        if input.dim() not in (4, 5):
            raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")
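
# Editorial sketch (not part of the original module): a minimal illustration
# of the num_features check in _InstanceNorm.forward above. A channel
# mismatch is a hard error when affine=True (the affine parameters have the
# wrong size), but only a warning when affine=False (num_features is unused).
# The helper name below is ours.
def _demo_num_features_check():
    import torch
    from torch import nn

    x = torch.randn(2, 7, 8, 8, 8)  # 7 channels

    try:
        nn.InstanceNorm3d(5, affine=True)(x)
        raise AssertionError("expected a ValueError for the channel mismatch")
    except ValueError:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        nn.InstanceNorm3d(5)(x)  # affine=False: runs, but warns
    assert any("num_features" in str(w.message) for w in caught)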