1# mypy: allow-untyped-defs 2import warnings 3from typing import Optional, Tuple 4 5import torch 6import torch.nn.functional as F 7from torch import Tensor 8from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 9from torch.nn.parameter import Parameter 10 11from .linear import NonDynamicallyQuantizableLinear 12from .module import Module 13 14 15__all__ = [ 16 "Threshold", 17 "ReLU", 18 "RReLU", 19 "Hardtanh", 20 "ReLU6", 21 "Sigmoid", 22 "Hardsigmoid", 23 "Tanh", 24 "SiLU", 25 "Mish", 26 "Hardswish", 27 "ELU", 28 "CELU", 29 "SELU", 30 "GLU", 31 "GELU", 32 "Hardshrink", 33 "LeakyReLU", 34 "LogSigmoid", 35 "Softplus", 36 "Softshrink", 37 "MultiheadAttention", 38 "PReLU", 39 "Softsign", 40 "Tanhshrink", 41 "Softmin", 42 "Softmax", 43 "Softmax2d", 44 "LogSoftmax", 45] 46 47 48class Threshold(Module): 49 r"""Thresholds each element of the input Tensor. 50 51 Threshold is defined as: 52 53 .. math:: 54 y = 55 \begin{cases} 56 x, &\text{ if } x > \text{threshold} \\ 57 \text{value}, &\text{ otherwise } 58 \end{cases} 59 60 Args: 61 threshold: The value to threshold at 62 value: The value to replace with 63 inplace: can optionally do the operation in-place. Default: ``False`` 64 65 Shape: 66 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 67 - Output: :math:`(*)`, same shape as the input. 68 69 Examples:: 70 71 >>> m = nn.Threshold(0.1, 20) 72 >>> input = torch.randn(2) 73 >>> output = m(input) 74 """ 75 76 __constants__ = ["threshold", "value", "inplace"] 77 78 threshold: float 79 value: float 80 inplace: bool 81 82 def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: 83 super().__init__() 84 self.threshold = threshold 85 self.value = value 86 self.inplace = inplace 87 # TODO: check in THNN (if inplace == True, then assert value <= threshold) 88 89 def forward(self, input: Tensor) -> Tensor: 90 return F.threshold(input, self.threshold, self.value, self.inplace) 91 92 def extra_repr(self): 93 inplace_str = ", inplace=True" if self.inplace else "" 94 return f"threshold={self.threshold}, value={self.value}{inplace_str}" 95 96 97class ReLU(Module): 98 r"""Applies the rectified linear unit function element-wise. 99 100 :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` 101 102 Args: 103 inplace: can optionally do the operation in-place. Default: ``False`` 104 105 Shape: 106 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 107 - Output: :math:`(*)`, same shape as the input. 108 109 .. image:: ../scripts/activation_images/ReLU.png 110 111 Examples:: 112 113 >>> m = nn.ReLU() 114 >>> input = torch.randn(2) 115 >>> output = m(input) 116 117 118 An implementation of CReLU - https://arxiv.org/abs/1603.05201 119 120 >>> m = nn.ReLU() 121 >>> input = torch.randn(2).unsqueeze(0) 122 >>> output = torch.cat((m(input), m(-input))) 123 """ 124 125 __constants__ = ["inplace"] 126 inplace: bool 127 128 def __init__(self, inplace: bool = False): 129 super().__init__() 130 self.inplace = inplace 131 132 def forward(self, input: Tensor) -> Tensor: 133 return F.relu(input, inplace=self.inplace) 134 135 def extra_repr(self) -> str: 136 inplace_str = "inplace=True" if self.inplace else "" 137 return inplace_str 138 139 140class RReLU(Module): 141 r"""Applies the randomized leaky rectified linear unit function, element-wise. 142 143 Method described in the paper: 144 `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_. 145 146 The function is defined as: 147 148 .. math:: 149 \text{RReLU}(x) = 150 \begin{cases} 151 x & \text{if } x \geq 0 \\ 152 ax & \text{ otherwise } 153 \end{cases} 154 155 where :math:`a` is randomly sampled from uniform distribution 156 :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during 157 evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`. 158 159 Args: 160 lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` 161 upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` 162 inplace: can optionally do the operation in-place. Default: ``False`` 163 164 Shape: 165 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 166 - Output: :math:`(*)`, same shape as the input. 167 168 .. image:: ../scripts/activation_images/RReLU.png 169 170 Examples:: 171 172 >>> m = nn.RReLU(0.1, 0.3) 173 >>> input = torch.randn(2) 174 >>> output = m(input) 175 176 """ 177 178 __constants__ = ["lower", "upper", "inplace"] 179 180 lower: float 181 upper: float 182 inplace: bool 183 184 def __init__( 185 self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False 186 ): 187 super().__init__() 188 self.lower = lower 189 self.upper = upper 190 self.inplace = inplace 191 192 def forward(self, input: Tensor) -> Tensor: 193 return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) 194 195 def extra_repr(self): 196 inplace_str = ", inplace=True" if self.inplace else "" 197 return f"lower={self.lower}, upper={self.upper}{inplace_str}" 198 199 200class Hardtanh(Module): 201 r"""Applies the HardTanh function element-wise. 202 203 HardTanh is defined as: 204 205 .. math:: 206 \text{HardTanh}(x) = \begin{cases} 207 \text{max\_val} & \text{ if } x > \text{ max\_val } \\ 208 \text{min\_val} & \text{ if } x < \text{ min\_val } \\ 209 x & \text{ otherwise } \\ 210 \end{cases} 211 212 Args: 213 min_val: minimum value of the linear region range. Default: -1 214 max_val: maximum value of the linear region range. Default: 1 215 inplace: can optionally do the operation in-place. Default: ``False`` 216 217 Keyword arguments :attr:`min_value` and :attr:`max_value` 218 have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. 219 220 Shape: 221 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 222 - Output: :math:`(*)`, same shape as the input. 223 224 .. image:: ../scripts/activation_images/Hardtanh.png 225 226 Examples:: 227 228 >>> m = nn.Hardtanh(-2, 2) 229 >>> input = torch.randn(2) 230 >>> output = m(input) 231 """ 232 233 __constants__ = ["min_val", "max_val", "inplace"] 234 235 min_val: float 236 max_val: float 237 inplace: bool 238 239 def __init__( 240 self, 241 min_val: float = -1.0, 242 max_val: float = 1.0, 243 inplace: bool = False, 244 min_value: Optional[float] = None, 245 max_value: Optional[float] = None, 246 ) -> None: 247 super().__init__() 248 if min_value is not None: 249 warnings.warn( 250 "keyword argument `min_value` is deprecated and rename to `min_val`", 251 FutureWarning, 252 stacklevel=2, 253 ) 254 min_val = min_value 255 if max_value is not None: 256 warnings.warn( 257 "keyword argument `max_value` is deprecated and rename to `max_val`", 258 FutureWarning, 259 stacklevel=2, 260 ) 261 max_val = max_value 262 263 self.min_val = min_val 264 self.max_val = max_val 265 self.inplace = inplace 266 assert self.max_val > self.min_val 267 268 def forward(self, input: Tensor) -> Tensor: 269 return F.hardtanh(input, self.min_val, self.max_val, self.inplace) 270 271 def extra_repr(self) -> str: 272 inplace_str = ", inplace=True" if self.inplace else "" 273 return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}" 274 275 276class ReLU6(Hardtanh): 277 r"""Applies the ReLU6 function element-wise. 278 279 .. math:: 280 \text{ReLU6}(x) = \min(\max(0,x), 6) 281 282 Args: 283 inplace: can optionally do the operation in-place. Default: ``False`` 284 285 Shape: 286 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 287 - Output: :math:`(*)`, same shape as the input. 288 289 .. image:: ../scripts/activation_images/ReLU6.png 290 291 Examples:: 292 293 >>> m = nn.ReLU6() 294 >>> input = torch.randn(2) 295 >>> output = m(input) 296 """ 297 298 def __init__(self, inplace: bool = False): 299 super().__init__(0.0, 6.0, inplace) 300 301 def extra_repr(self) -> str: 302 inplace_str = "inplace=True" if self.inplace else "" 303 return inplace_str 304 305 306class Sigmoid(Module): 307 r"""Applies the Sigmoid function element-wise. 308 309 .. math:: 310 \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} 311 312 313 Shape: 314 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 315 - Output: :math:`(*)`, same shape as the input. 316 317 .. image:: ../scripts/activation_images/Sigmoid.png 318 319 Examples:: 320 321 >>> m = nn.Sigmoid() 322 >>> input = torch.randn(2) 323 >>> output = m(input) 324 """ 325 326 def forward(self, input: Tensor) -> Tensor: 327 return torch.sigmoid(input) 328 329 330class Hardsigmoid(Module): 331 r"""Applies the Hardsigmoid function element-wise. 332 333 Hardsigmoid is defined as: 334 335 .. math:: 336 \text{Hardsigmoid}(x) = \begin{cases} 337 0 & \text{if~} x \le -3, \\ 338 1 & \text{if~} x \ge +3, \\ 339 x / 6 + 1 / 2 & \text{otherwise} 340 \end{cases} 341 342 Args: 343 inplace: can optionally do the operation in-place. Default: ``False`` 344 345 Shape: 346 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 347 - Output: :math:`(*)`, same shape as the input. 348 349 .. image:: ../scripts/activation_images/Hardsigmoid.png 350 351 Examples:: 352 353 >>> m = nn.Hardsigmoid() 354 >>> input = torch.randn(2) 355 >>> output = m(input) 356 """ 357 358 __constants__ = ["inplace"] 359 360 inplace: bool 361 362 def __init__(self, inplace: bool = False) -> None: 363 super().__init__() 364 self.inplace = inplace 365 366 def forward(self, input: Tensor) -> Tensor: 367 return F.hardsigmoid(input, self.inplace) 368 369 370class Tanh(Module): 371 r"""Applies the Hyperbolic Tangent (Tanh) function element-wise. 372 373 Tanh is defined as: 374 375 .. math:: 376 \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} 377 378 Shape: 379 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 380 - Output: :math:`(*)`, same shape as the input. 381 382 .. image:: ../scripts/activation_images/Tanh.png 383 384 Examples:: 385 386 >>> m = nn.Tanh() 387 >>> input = torch.randn(2) 388 >>> output = m(input) 389 """ 390 391 def forward(self, input: Tensor) -> Tensor: 392 return torch.tanh(input) 393 394 395class SiLU(Module): 396 r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. 397 398 The SiLU function is also known as the swish function. 399 400 .. math:: 401 \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} 402 403 .. note:: 404 See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_ 405 where the SiLU (Sigmoid Linear Unit) was originally coined, and see 406 `Sigmoid-Weighted Linear Units for Neural Network Function Approximation 407 in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish: 408 a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_ 409 where the SiLU was experimented with later. 410 411 Shape: 412 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 413 - Output: :math:`(*)`, same shape as the input. 414 415 .. image:: ../scripts/activation_images/SiLU.png 416 417 Examples:: 418 419 >>> m = nn.SiLU() 420 >>> input = torch.randn(2) 421 >>> output = m(input) 422 """ 423 424 __constants__ = ["inplace"] 425 inplace: bool 426 427 def __init__(self, inplace: bool = False): 428 super().__init__() 429 self.inplace = inplace 430 431 def forward(self, input: Tensor) -> Tensor: 432 return F.silu(input, inplace=self.inplace) 433 434 def extra_repr(self) -> str: 435 inplace_str = "inplace=True" if self.inplace else "" 436 return inplace_str 437 438 439class Mish(Module): 440 r"""Applies the Mish function, element-wise. 441 442 Mish: A Self Regularized Non-Monotonic Neural Activation Function. 443 444 .. math:: 445 \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) 446 447 .. note:: 448 See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_ 449 450 Shape: 451 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 452 - Output: :math:`(*)`, same shape as the input. 453 454 .. image:: ../scripts/activation_images/Mish.png 455 456 Examples:: 457 458 >>> m = nn.Mish() 459 >>> input = torch.randn(2) 460 >>> output = m(input) 461 """ 462 463 __constants__ = ["inplace"] 464 inplace: bool 465 466 def __init__(self, inplace: bool = False): 467 super().__init__() 468 self.inplace = inplace 469 470 def forward(self, input: Tensor) -> Tensor: 471 return F.mish(input, inplace=self.inplace) 472 473 def extra_repr(self) -> str: 474 inplace_str = "inplace=True" if self.inplace else "" 475 return inplace_str 476 477 478class Hardswish(Module): 479 r"""Applies the Hardswish function, element-wise. 480 481 Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_. 482 483 Hardswish is defined as: 484 485 .. math:: 486 \text{Hardswish}(x) = \begin{cases} 487 0 & \text{if~} x \le -3, \\ 488 x & \text{if~} x \ge +3, \\ 489 x \cdot (x + 3) /6 & \text{otherwise} 490 \end{cases} 491 492 Args: 493 inplace: can optionally do the operation in-place. Default: ``False`` 494 495 Shape: 496 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 497 - Output: :math:`(*)`, same shape as the input. 498 499 .. image:: ../scripts/activation_images/Hardswish.png 500 501 Examples:: 502 503 >>> m = nn.Hardswish() 504 >>> input = torch.randn(2) 505 >>> output = m(input) 506 """ 507 508 __constants__ = ["inplace"] 509 510 inplace: bool 511 512 def __init__(self, inplace: bool = False) -> None: 513 super().__init__() 514 self.inplace = inplace 515 516 def forward(self, input: Tensor) -> Tensor: 517 return F.hardswish(input, self.inplace) 518 519 520class ELU(Module): 521 r"""Applies the Exponential Linear Unit (ELU) function, element-wise. 522 523 Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear 524 Units (ELUs) <https://arxiv.org/abs/1511.07289>`__. 525 526 ELU is defined as: 527 528 .. math:: 529 \text{ELU}(x) = \begin{cases} 530 x, & \text{ if } x > 0\\ 531 \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 532 \end{cases} 533 534 Args: 535 alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 536 inplace: can optionally do the operation in-place. Default: ``False`` 537 538 Shape: 539 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 540 - Output: :math:`(*)`, same shape as the input. 541 542 .. image:: ../scripts/activation_images/ELU.png 543 544 Examples:: 545 546 >>> m = nn.ELU() 547 >>> input = torch.randn(2) 548 >>> output = m(input) 549 """ 550 551 __constants__ = ["alpha", "inplace"] 552 alpha: float 553 inplace: bool 554 555 def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: 556 super().__init__() 557 self.alpha = alpha 558 self.inplace = inplace 559 560 def forward(self, input: Tensor) -> Tensor: 561 return F.elu(input, self.alpha, self.inplace) 562 563 def extra_repr(self) -> str: 564 inplace_str = ", inplace=True" if self.inplace else "" 565 return f"alpha={self.alpha}{inplace_str}" 566 567 568class CELU(Module): 569 r"""Applies the CELU function element-wise. 570 571 .. math:: 572 \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) 573 574 More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . 575 576 Args: 577 alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 578 inplace: can optionally do the operation in-place. Default: ``False`` 579 580 Shape: 581 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 582 - Output: :math:`(*)`, same shape as the input. 583 584 .. image:: ../scripts/activation_images/CELU.png 585 586 Examples:: 587 588 >>> m = nn.CELU() 589 >>> input = torch.randn(2) 590 >>> output = m(input) 591 592 .. _`Continuously Differentiable Exponential Linear Units`: 593 https://arxiv.org/abs/1704.07483 594 """ 595 596 __constants__ = ["alpha", "inplace"] 597 alpha: float 598 inplace: bool 599 600 def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: 601 super().__init__() 602 self.alpha = alpha 603 self.inplace = inplace 604 605 def forward(self, input: Tensor) -> Tensor: 606 return F.celu(input, self.alpha, self.inplace) 607 608 def extra_repr(self) -> str: 609 inplace_str = ", inplace=True" if self.inplace else "" 610 return f"alpha={self.alpha}{inplace_str}" 611 612 613class SELU(Module): 614 r"""Applies the SELU function element-wise. 615 616 .. math:: 617 \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) 618 619 with :math:`\alpha = 1.6732632423543772848170429916717` and 620 :math:`\text{scale} = 1.0507009873554804934193349852946`. 621 622 .. warning:: 623 When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, 624 ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` 625 in order to get `Self-Normalizing Neural Networks`_. 626 See :func:`torch.nn.init.calculate_gain` for more information. 627 628 More details can be found in the paper `Self-Normalizing Neural Networks`_ . 629 630 Args: 631 inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` 632 633 Shape: 634 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 635 - Output: :math:`(*)`, same shape as the input. 636 637 .. image:: ../scripts/activation_images/SELU.png 638 639 Examples:: 640 641 >>> m = nn.SELU() 642 >>> input = torch.randn(2) 643 >>> output = m(input) 644 645 .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 646 """ 647 648 __constants__ = ["inplace"] 649 inplace: bool 650 651 def __init__(self, inplace: bool = False) -> None: 652 super().__init__() 653 self.inplace = inplace 654 655 def forward(self, input: Tensor) -> Tensor: 656 return F.selu(input, self.inplace) 657 658 def extra_repr(self) -> str: 659 inplace_str = "inplace=True" if self.inplace else "" 660 return inplace_str 661 662 663class GLU(Module): 664 r"""Applies the gated linear unit function. 665 666 :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half 667 of the input matrices and :math:`b` is the second half. 668 669 Args: 670 dim (int): the dimension on which to split the input. Default: -1 671 672 Shape: 673 - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional 674 dimensions 675 - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` 676 677 Examples:: 678 679 >>> m = nn.GLU() 680 >>> input = torch.randn(4, 2) 681 >>> output = m(input) 682 """ 683 684 __constants__ = ["dim"] 685 dim: int 686 687 def __init__(self, dim: int = -1) -> None: 688 super().__init__() 689 self.dim = dim 690 691 def forward(self, input: Tensor) -> Tensor: 692 return F.glu(input, self.dim) 693 694 def extra_repr(self) -> str: 695 return f"dim={self.dim}" 696 697 698class GELU(Module): 699 r"""Applies the Gaussian Error Linear Units function. 700 701 .. math:: \text{GELU}(x) = x * \Phi(x) 702 703 where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. 704 705 When the approximate argument is 'tanh', Gelu is estimated with: 706 707 .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) 708 709 Args: 710 approximate (str, optional): the gelu approximation algorithm to use: 711 ``'none'`` | ``'tanh'``. Default: ``'none'`` 712 713 Shape: 714 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 715 - Output: :math:`(*)`, same shape as the input. 716 717 .. image:: ../scripts/activation_images/GELU.png 718 719 Examples:: 720 721 >>> m = nn.GELU() 722 >>> input = torch.randn(2) 723 >>> output = m(input) 724 """ 725 726 __constants__ = ["approximate"] 727 approximate: str 728 729 def __init__(self, approximate: str = "none") -> None: 730 super().__init__() 731 self.approximate = approximate 732 733 def forward(self, input: Tensor) -> Tensor: 734 return F.gelu(input, approximate=self.approximate) 735 736 def extra_repr(self) -> str: 737 return f"approximate={repr(self.approximate)}" 738 739 740class Hardshrink(Module): 741 r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. 742 743 Hardshrink is defined as: 744 745 .. math:: 746 \text{HardShrink}(x) = 747 \begin{cases} 748 x, & \text{ if } x > \lambda \\ 749 x, & \text{ if } x < -\lambda \\ 750 0, & \text{ otherwise } 751 \end{cases} 752 753 Args: 754 lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 755 756 Shape: 757 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 758 - Output: :math:`(*)`, same shape as the input. 759 760 .. image:: ../scripts/activation_images/Hardshrink.png 761 762 Examples:: 763 764 >>> m = nn.Hardshrink() 765 >>> input = torch.randn(2) 766 >>> output = m(input) 767 """ 768 769 __constants__ = ["lambd"] 770 lambd: float 771 772 def __init__(self, lambd: float = 0.5) -> None: 773 super().__init__() 774 self.lambd = lambd 775 776 def forward(self, input: Tensor) -> Tensor: 777 return F.hardshrink(input, self.lambd) 778 779 def extra_repr(self) -> str: 780 return f"{self.lambd}" 781 782 783class LeakyReLU(Module): 784 r"""Applies the LeakyReLU function element-wise. 785 786 .. math:: 787 \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) 788 789 790 or 791 792 .. math:: 793 \text{LeakyReLU}(x) = 794 \begin{cases} 795 x, & \text{ if } x \geq 0 \\ 796 \text{negative\_slope} \times x, & \text{ otherwise } 797 \end{cases} 798 799 Args: 800 negative_slope: Controls the angle of the negative slope (which is used for 801 negative input values). Default: 1e-2 802 inplace: can optionally do the operation in-place. Default: ``False`` 803 804 Shape: 805 - Input: :math:`(*)` where `*` means, any number of additional 806 dimensions 807 - Output: :math:`(*)`, same shape as the input 808 809 .. image:: ../scripts/activation_images/LeakyReLU.png 810 811 Examples:: 812 813 >>> m = nn.LeakyReLU(0.1) 814 >>> input = torch.randn(2) 815 >>> output = m(input) 816 """ 817 818 __constants__ = ["inplace", "negative_slope"] 819 inplace: bool 820 negative_slope: float 821 822 def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: 823 super().__init__() 824 self.negative_slope = negative_slope 825 self.inplace = inplace 826 827 def forward(self, input: Tensor) -> Tensor: 828 return F.leaky_relu(input, self.negative_slope, self.inplace) 829 830 def extra_repr(self) -> str: 831 inplace_str = ", inplace=True" if self.inplace else "" 832 return f"negative_slope={self.negative_slope}{inplace_str}" 833 834 835class LogSigmoid(Module): 836 r"""Applies the Logsigmoid function element-wise. 837 838 .. math:: 839 \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) 840 841 Shape: 842 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 843 - Output: :math:`(*)`, same shape as the input. 844 845 .. image:: ../scripts/activation_images/LogSigmoid.png 846 847 Examples:: 848 849 >>> m = nn.LogSigmoid() 850 >>> input = torch.randn(2) 851 >>> output = m(input) 852 """ 853 854 def forward(self, input: Tensor) -> Tensor: 855 return F.logsigmoid(input) 856 857 858class Softplus(Module): 859 r"""Applies the Softplus function element-wise. 860 861 .. math:: 862 \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) 863 864 SoftPlus is a smooth approximation to the ReLU function and can be used 865 to constrain the output of a machine to always be positive. 866 867 For numerical stability the implementation reverts to the linear function 868 when :math:`input \times \beta > threshold`. 869 870 Args: 871 beta: the :math:`\beta` value for the Softplus formulation. Default: 1 872 threshold: values above this revert to a linear function. Default: 20 873 874 Shape: 875 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 876 - Output: :math:`(*)`, same shape as the input. 877 878 .. image:: ../scripts/activation_images/Softplus.png 879 880 Examples:: 881 882 >>> m = nn.Softplus() 883 >>> input = torch.randn(2) 884 >>> output = m(input) 885 """ 886 887 __constants__ = ["beta", "threshold"] 888 beta: float 889 threshold: float 890 891 def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None: 892 super().__init__() 893 self.beta = beta 894 self.threshold = threshold 895 896 def forward(self, input: Tensor) -> Tensor: 897 return F.softplus(input, self.beta, self.threshold) 898 899 def extra_repr(self) -> str: 900 return f"beta={self.beta}, threshold={self.threshold}" 901 902 903class Softshrink(Module): 904 r"""Applies the soft shrinkage function element-wise. 905 906 .. math:: 907 \text{SoftShrinkage}(x) = 908 \begin{cases} 909 x - \lambda, & \text{ if } x > \lambda \\ 910 x + \lambda, & \text{ if } x < -\lambda \\ 911 0, & \text{ otherwise } 912 \end{cases} 913 914 Args: 915 lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 916 917 Shape: 918 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 919 - Output: :math:`(*)`, same shape as the input. 920 921 .. image:: ../scripts/activation_images/Softshrink.png 922 923 Examples:: 924 925 >>> m = nn.Softshrink() 926 >>> input = torch.randn(2) 927 >>> output = m(input) 928 """ 929 930 __constants__ = ["lambd"] 931 lambd: float 932 933 def __init__(self, lambd: float = 0.5) -> None: 934 super().__init__() 935 self.lambd = lambd 936 937 def forward(self, input: Tensor) -> Tensor: 938 return F.softshrink(input, self.lambd) 939 940 def extra_repr(self) -> str: 941 return str(self.lambd) 942 943 944def _check_arg_device(x: Optional[torch.Tensor]) -> bool: 945 if x is not None: 946 return x.device.type in [ 947 "cpu", 948 "cuda", 949 torch.utils.backend_registration._privateuse1_backend_name, 950 ] 951 return True 952 953 954def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: 955 if x is not None: 956 return x.requires_grad 957 return False 958 959 960def _is_make_fx_tracing(): 961 if not torch.jit.is_scripting(): 962 torch_dispatch_mode_stack = ( 963 torch.utils._python_dispatch._get_current_dispatch_mode_stack() 964 ) 965 return any( 966 type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode 967 for x in torch_dispatch_mode_stack 968 ) 969 else: 970 return False 971 972 973class MultiheadAttention(Module): 974 r"""Allows the model to jointly attend to information from different representation subspaces. 975 976 Method described in the paper: 977 `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. 978 979 Multi-Head Attention is defined as: 980 981 .. math:: 982 \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 983 984 where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. 985 986 ``nn.MultiHeadAttention`` will use the optimized implementations of 987 ``scaled_dot_product_attention()`` when possible. 988 989 In addition to support for the new ``scaled_dot_product_attention()`` 990 function, for speeding up Inference, MHA will use 991 fastpath inference with support for Nested Tensors, iff: 992 993 - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor). 994 - inputs are batched (3D) with ``batch_first==True`` 995 - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` 996 - training is disabled (using ``.eval()``) 997 - ``add_bias_kv`` is ``False`` 998 - ``add_zero_attn`` is ``False`` 999 - ``kdim`` and ``vdim`` are equal to ``embed_dim`` 1000 - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask`` 1001 nor ``attn_mask`` is passed 1002 - autocast is disabled 1003 1004 If the optimized inference fastpath implementation is in use, a 1005 `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for 1006 ``query``/``key``/``value`` to represent padding more efficiently than using a 1007 padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ 1008 will be returned, and an additional speedup proportional to the fraction of the input 1009 that is padding can be expected. 1010 1011 Args: 1012 embed_dim: Total dimension of the model. 1013 num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split 1014 across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). 1015 dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). 1016 bias: If specified, adds bias to input / output projection layers. Default: ``True``. 1017 add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. 1018 add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. 1019 Default: ``False``. 1020 kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 1021 vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). 1022 batch_first: If ``True``, then the input and output tensors are provided 1023 as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 1024 1025 Examples:: 1026 1027 >>> # xdoctest: +SKIP 1028 >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 1029 >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 1030 1031 .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: 1032 https://arxiv.org/abs/2205.14135 1033 1034 """ 1035 1036 __constants__ = ["batch_first"] 1037 bias_k: Optional[torch.Tensor] 1038 bias_v: Optional[torch.Tensor] 1039 1040 def __init__( 1041 self, 1042 embed_dim, 1043 num_heads, 1044 dropout=0.0, 1045 bias=True, 1046 add_bias_kv=False, 1047 add_zero_attn=False, 1048 kdim=None, 1049 vdim=None, 1050 batch_first=False, 1051 device=None, 1052 dtype=None, 1053 ) -> None: 1054 if embed_dim <= 0 or num_heads <= 0: 1055 raise ValueError( 1056 f"embed_dim and num_heads must be greater than 0," 1057 f" got embed_dim={embed_dim} and num_heads={num_heads} instead" 1058 ) 1059 factory_kwargs = {"device": device, "dtype": dtype} 1060 super().__init__() 1061 self.embed_dim = embed_dim 1062 self.kdim = kdim if kdim is not None else embed_dim 1063 self.vdim = vdim if vdim is not None else embed_dim 1064 self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 1065 1066 self.num_heads = num_heads 1067 self.dropout = dropout 1068 self.batch_first = batch_first 1069 self.head_dim = embed_dim // num_heads 1070 assert ( 1071 self.head_dim * num_heads == self.embed_dim 1072 ), "embed_dim must be divisible by num_heads" 1073 1074 if not self._qkv_same_embed_dim: 1075 self.q_proj_weight = Parameter( 1076 torch.empty((embed_dim, embed_dim), **factory_kwargs) 1077 ) 1078 self.k_proj_weight = Parameter( 1079 torch.empty((embed_dim, self.kdim), **factory_kwargs) 1080 ) 1081 self.v_proj_weight = Parameter( 1082 torch.empty((embed_dim, self.vdim), **factory_kwargs) 1083 ) 1084 self.register_parameter("in_proj_weight", None) 1085 else: 1086 self.in_proj_weight = Parameter( 1087 torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) 1088 ) 1089 self.register_parameter("q_proj_weight", None) 1090 self.register_parameter("k_proj_weight", None) 1091 self.register_parameter("v_proj_weight", None) 1092 1093 if bias: 1094 self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) 1095 else: 1096 self.register_parameter("in_proj_bias", None) 1097 self.out_proj = NonDynamicallyQuantizableLinear( 1098 embed_dim, embed_dim, bias=bias, **factory_kwargs 1099 ) 1100 1101 if add_bias_kv: 1102 self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 1103 self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 1104 else: 1105 self.bias_k = self.bias_v = None 1106 1107 self.add_zero_attn = add_zero_attn 1108 1109 self._reset_parameters() 1110 1111 def _reset_parameters(self): 1112 if self._qkv_same_embed_dim: 1113 xavier_uniform_(self.in_proj_weight) 1114 else: 1115 xavier_uniform_(self.q_proj_weight) 1116 xavier_uniform_(self.k_proj_weight) 1117 xavier_uniform_(self.v_proj_weight) 1118 1119 if self.in_proj_bias is not None: 1120 constant_(self.in_proj_bias, 0.0) 1121 constant_(self.out_proj.bias, 0.0) 1122 if self.bias_k is not None: 1123 xavier_normal_(self.bias_k) 1124 if self.bias_v is not None: 1125 xavier_normal_(self.bias_v) 1126 1127 def __setstate__(self, state): 1128 # Support loading old MultiheadAttention checkpoints generated by v1.1.0 1129 if "_qkv_same_embed_dim" not in state: 1130 state["_qkv_same_embed_dim"] = True 1131 1132 super().__setstate__(state) 1133 1134 def forward( 1135 self, 1136 query: Tensor, 1137 key: Tensor, 1138 value: Tensor, 1139 key_padding_mask: Optional[Tensor] = None, 1140 need_weights: bool = True, 1141 attn_mask: Optional[Tensor] = None, 1142 average_attn_weights: bool = True, 1143 is_causal: bool = False, 1144 ) -> Tuple[Tensor, Optional[Tensor]]: 1145 r"""Compute attention outputs using query, key, and value embeddings. 1146 1147 Supports optional parameters for padding, masks and attention weights. 1148 1149 Args: 1150 query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` 1151 or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, 1152 :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. 1153 Queries are compared against key-value pairs to produce the output. 1154 See "Attention Is All You Need" for more details. 1155 key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` 1156 or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, 1157 :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. 1158 See "Attention Is All You Need" for more details. 1159 value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when 1160 ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source 1161 sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. 1162 See "Attention Is All You Need" for more details. 1163 key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` 1164 to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. 1165 Binary and float masks are supported. 1166 For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for 1167 the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. 1168 need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. 1169 Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` 1170 and achieve the best performance for MHA. 1171 Default: ``True``. 1172 attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape 1173 :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, 1174 :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be 1175 broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. 1176 Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the 1177 corresponding position is not allowed to attend. For a float mask, the mask values will be added to 1178 the attention weight. 1179 If both attn_mask and key_padding_mask are supplied, their types should match. 1180 average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across 1181 heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an 1182 effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) 1183 is_causal: If specified, applies a causal mask as attention mask. 1184 Default: ``False``. 1185 Warning: 1186 ``is_causal`` provides a hint that ``attn_mask`` is the 1187 causal mask. Providing incorrect hints can result in 1188 incorrect execution, including forward and backward 1189 compatibility. 1190 1191 Outputs: 1192 - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, 1193 :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, 1194 where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the 1195 embedding dimension ``embed_dim``. 1196 - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, 1197 returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or 1198 :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and 1199 :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per 1200 head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. 1201 1202 .. note:: 1203 `batch_first` argument is ignored for unbatched inputs. 1204 """ # noqa: B950 1205 why_not_fast_path = "" 1206 if ( 1207 (attn_mask is not None and torch.is_floating_point(attn_mask)) 1208 or (key_padding_mask is not None) 1209 and torch.is_floating_point(key_padding_mask) 1210 ): 1211 why_not_fast_path = "floating-point masks are not supported for fast path." 1212 1213 is_batched = query.dim() == 3 1214 1215 key_padding_mask = F._canonical_mask( 1216 mask=key_padding_mask, 1217 mask_name="key_padding_mask", 1218 other_type=F._none_or_dtype(attn_mask), 1219 other_name="attn_mask", 1220 target_type=query.dtype, 1221 ) 1222 1223 attn_mask = F._canonical_mask( 1224 mask=attn_mask, 1225 mask_name="attn_mask", 1226 other_type=None, 1227 other_name="", 1228 target_type=query.dtype, 1229 check_other=False, 1230 ) 1231 1232 is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() 1233 1234 if not is_fastpath_enabled: 1235 why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" 1236 elif not is_batched: 1237 why_not_fast_path = ( 1238 f"input not batched; expected query.dim() of 3 but got {query.dim()}" 1239 ) 1240 elif query is not key or key is not value: 1241 # When lifting this restriction, don't forget to either 1242 # enforce that the dtypes all match or test cases where 1243 # they don't! 1244 why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" 1245 elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: 1246 why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" 1247 elif self.in_proj_weight is None: 1248 why_not_fast_path = "in_proj_weight was None" 1249 elif query.dtype != self.in_proj_weight.dtype: 1250 # this case will fail anyway, but at least they'll get a useful error message. 1251 why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" 1252 elif self.training: 1253 why_not_fast_path = "training is enabled" 1254 elif (self.num_heads % 2) != 0: 1255 why_not_fast_path = "self.num_heads is not even" 1256 elif not self.batch_first: 1257 why_not_fast_path = "batch_first was not True" 1258 elif self.bias_k is not None: 1259 why_not_fast_path = "self.bias_k was not None" 1260 elif self.bias_v is not None: 1261 why_not_fast_path = "self.bias_v was not None" 1262 elif self.add_zero_attn: 1263 why_not_fast_path = "add_zero_attn was enabled" 1264 elif not self._qkv_same_embed_dim: 1265 why_not_fast_path = "_qkv_same_embed_dim was not True" 1266 elif query.is_nested and ( 1267 key_padding_mask is not None or attn_mask is not None 1268 ): 1269 why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ 1270 is not supported with NestedTensor input" 1271 elif torch.is_autocast_enabled(): 1272 why_not_fast_path = "autocast is enabled" 1273 1274 if not why_not_fast_path: 1275 tensor_args = ( 1276 query, 1277 key, 1278 value, 1279 self.in_proj_weight, 1280 self.in_proj_bias, 1281 self.out_proj.weight, 1282 self.out_proj.bias, 1283 ) 1284 # We have to use list comprehensions below because TorchScript does not support 1285 # generator expressions. 1286 if torch.overrides.has_torch_function(tensor_args): 1287 why_not_fast_path = "some Tensor argument has_torch_function" 1288 elif _is_make_fx_tracing(): 1289 why_not_fast_path = "we are running make_fx tracing" 1290 elif not all(_check_arg_device(x) for x in tensor_args): 1291 why_not_fast_path = ( 1292 "some Tensor argument's device is neither one of " 1293 f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}" 1294 ) 1295 elif torch.is_grad_enabled() and any( 1296 _arg_requires_grad(x) for x in tensor_args 1297 ): 1298 why_not_fast_path = ( 1299 "grad is enabled and at least one of query or the " 1300 "input/output projection weights or biases requires_grad" 1301 ) 1302 if not why_not_fast_path: 1303 merged_mask, mask_type = self.merge_masks( 1304 attn_mask, key_padding_mask, query 1305 ) 1306 1307 if self.in_proj_bias is not None and self.in_proj_weight is not None: 1308 return torch._native_multi_head_attention( 1309 query, 1310 key, 1311 value, 1312 self.embed_dim, 1313 self.num_heads, 1314 self.in_proj_weight, 1315 self.in_proj_bias, 1316 self.out_proj.weight, 1317 self.out_proj.bias, 1318 merged_mask, 1319 need_weights, 1320 average_attn_weights, 1321 mask_type, 1322 ) 1323 1324 any_nested = query.is_nested or key.is_nested or value.is_nested 1325 assert not any_nested, ( 1326 "MultiheadAttention does not support NestedTensor outside of its fast path. " 1327 + f"The fast path was not hit because {why_not_fast_path}" 1328 ) 1329 1330 if self.batch_first and is_batched: 1331 # make sure that the transpose op does not affect the "is" property 1332 if key is value: 1333 if query is key: 1334 query = key = value = query.transpose(1, 0) 1335 else: 1336 query, key = (x.transpose(1, 0) for x in (query, key)) 1337 value = key 1338 else: 1339 query, key, value = (x.transpose(1, 0) for x in (query, key, value)) 1340 1341 if not self._qkv_same_embed_dim: 1342 attn_output, attn_output_weights = F.multi_head_attention_forward( 1343 query, 1344 key, 1345 value, 1346 self.embed_dim, 1347 self.num_heads, 1348 self.in_proj_weight, 1349 self.in_proj_bias, 1350 self.bias_k, 1351 self.bias_v, 1352 self.add_zero_attn, 1353 self.dropout, 1354 self.out_proj.weight, 1355 self.out_proj.bias, 1356 training=self.training, 1357 key_padding_mask=key_padding_mask, 1358 need_weights=need_weights, 1359 attn_mask=attn_mask, 1360 use_separate_proj_weight=True, 1361 q_proj_weight=self.q_proj_weight, 1362 k_proj_weight=self.k_proj_weight, 1363 v_proj_weight=self.v_proj_weight, 1364 average_attn_weights=average_attn_weights, 1365 is_causal=is_causal, 1366 ) 1367 else: 1368 attn_output, attn_output_weights = F.multi_head_attention_forward( 1369 query, 1370 key, 1371 value, 1372 self.embed_dim, 1373 self.num_heads, 1374 self.in_proj_weight, 1375 self.in_proj_bias, 1376 self.bias_k, 1377 self.bias_v, 1378 self.add_zero_attn, 1379 self.dropout, 1380 self.out_proj.weight, 1381 self.out_proj.bias, 1382 training=self.training, 1383 key_padding_mask=key_padding_mask, 1384 need_weights=need_weights, 1385 attn_mask=attn_mask, 1386 average_attn_weights=average_attn_weights, 1387 is_causal=is_causal, 1388 ) 1389 if self.batch_first and is_batched: 1390 return attn_output.transpose(1, 0), attn_output_weights 1391 else: 1392 return attn_output, attn_output_weights 1393 1394 def merge_masks( 1395 self, 1396 attn_mask: Optional[Tensor], 1397 key_padding_mask: Optional[Tensor], 1398 query: Tensor, 1399 ) -> Tuple[Optional[Tensor], Optional[int]]: 1400 r"""Determine mask type and combine masks if necessary. 1401 1402 If only one mask is provided, that mask 1403 and the corresponding mask type will be returned. If both masks are provided, they will be both 1404 expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` 1405 and mask type 2 will be returned 1406 Args: 1407 attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 1408 key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 1409 query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` 1410 Returns: 1411 merged_mask: merged mask 1412 mask_type: merged mask type (0, 1, or 2) 1413 """ 1414 mask_type: Optional[int] = None 1415 merged_mask: Optional[Tensor] = None 1416 1417 if key_padding_mask is not None: 1418 mask_type = 1 1419 merged_mask = key_padding_mask 1420 1421 if attn_mask is not None: 1422 # In this branch query can't be a nested tensor, so it has a shape 1423 batch_size, seq_len, _ = query.shape 1424 mask_type = 2 1425 1426 # Always expands attn_mask to 4D 1427 if attn_mask.dim() == 3: 1428 attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) 1429 else: # attn_mask.dim() == 2: 1430 attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( 1431 batch_size, self.num_heads, -1, -1 1432 ) 1433 merged_mask = attn_mask_expanded 1434 1435 if key_padding_mask is not None: 1436 key_padding_mask_expanded = key_padding_mask.view( 1437 batch_size, 1, 1, seq_len 1438 ).expand(-1, self.num_heads, -1, -1) 1439 merged_mask = attn_mask_expanded + key_padding_mask_expanded 1440 1441 # no attn_mask and no key_padding_mask, returns None, None 1442 return merged_mask, mask_type 1443 1444 1445class PReLU(Module): 1446 r"""Applies the element-wise PReLU function. 1447 1448 .. math:: 1449 \text{PReLU}(x) = \max(0,x) + a * \min(0,x) 1450 1451 or 1452 1453 .. math:: 1454 \text{PReLU}(x) = 1455 \begin{cases} 1456 x, & \text{ if } x \ge 0 \\ 1457 ax, & \text{ otherwise } 1458 \end{cases} 1459 1460 Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single 1461 parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, 1462 a separate :math:`a` is used for each input channel. 1463 1464 1465 .. note:: 1466 weight decay should not be used when learning :math:`a` for good performance. 1467 1468 .. note:: 1469 Channel dim is the 2nd dim of input. When input has dims < 2, then there is 1470 no channel dim and the number of channels = 1. 1471 1472 Args: 1473 num_parameters (int): number of :math:`a` to learn. 1474 Although it takes an int as input, there is only two values are legitimate: 1475 1, or the number of channels at input. Default: 1 1476 init (float): the initial value of :math:`a`. Default: 0.25 1477 1478 Shape: 1479 - Input: :math:`( *)` where `*` means, any number of additional 1480 dimensions. 1481 - Output: :math:`(*)`, same shape as the input. 1482 1483 Attributes: 1484 weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). 1485 1486 .. image:: ../scripts/activation_images/PReLU.png 1487 1488 Examples:: 1489 1490 >>> m = nn.PReLU() 1491 >>> input = torch.randn(2) 1492 >>> output = m(input) 1493 """ 1494 1495 __constants__ = ["num_parameters"] 1496 num_parameters: int 1497 1498 def __init__( 1499 self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None 1500 ) -> None: 1501 factory_kwargs = {"device": device, "dtype": dtype} 1502 self.num_parameters = num_parameters 1503 super().__init__() 1504 self.init = init 1505 self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs)) 1506 self.reset_parameters() 1507 1508 def reset_parameters(self): 1509 torch.nn.init.constant_(self.weight, self.init) 1510 1511 def forward(self, input: Tensor) -> Tensor: 1512 return F.prelu(input, self.weight) 1513 1514 def extra_repr(self) -> str: 1515 return f"num_parameters={self.num_parameters}" 1516 1517 1518class Softsign(Module): 1519 r"""Applies the element-wise Softsign function. 1520 1521 .. math:: 1522 \text{SoftSign}(x) = \frac{x}{ 1 + |x|} 1523 1524 Shape: 1525 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 1526 - Output: :math:`(*)`, same shape as the input. 1527 1528 .. image:: ../scripts/activation_images/Softsign.png 1529 1530 Examples:: 1531 1532 >>> m = nn.Softsign() 1533 >>> input = torch.randn(2) 1534 >>> output = m(input) 1535 """ 1536 1537 def forward(self, input: Tensor) -> Tensor: 1538 return F.softsign(input) 1539 1540 1541class Tanhshrink(Module): 1542 r"""Applies the element-wise Tanhshrink function. 1543 1544 .. math:: 1545 \text{Tanhshrink}(x) = x - \tanh(x) 1546 1547 Shape: 1548 - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 1549 - Output: :math:`(*)`, same shape as the input. 1550 1551 .. image:: ../scripts/activation_images/Tanhshrink.png 1552 1553 Examples:: 1554 1555 >>> m = nn.Tanhshrink() 1556 >>> input = torch.randn(2) 1557 >>> output = m(input) 1558 """ 1559 1560 def forward(self, input: Tensor) -> Tensor: 1561 return F.tanhshrink(input) 1562 1563 1564class Softmin(Module): 1565 r"""Applies the Softmin function to an n-dimensional input Tensor. 1566 1567 Rescales them so that the elements of the n-dimensional output Tensor 1568 lie in the range `[0, 1]` and sum to 1. 1569 1570 Softmin is defined as: 1571 1572 .. math:: 1573 \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} 1574 1575 Shape: 1576 - Input: :math:`(*)` where `*` means, any number of additional 1577 dimensions 1578 - Output: :math:`(*)`, same shape as the input 1579 1580 Args: 1581 dim (int): A dimension along which Softmin will be computed (so every slice 1582 along dim will sum to 1). 1583 1584 Returns: 1585 a Tensor of the same dimension and shape as the input, with 1586 values in the range [0, 1] 1587 1588 Examples:: 1589 1590 >>> m = nn.Softmin(dim=1) 1591 >>> input = torch.randn(2, 3) 1592 >>> output = m(input) 1593 """ 1594 1595 __constants__ = ["dim"] 1596 dim: Optional[int] 1597 1598 def __init__(self, dim: Optional[int] = None) -> None: 1599 super().__init__() 1600 self.dim = dim 1601 1602 def __setstate__(self, state): 1603 super().__setstate__(state) 1604 if not hasattr(self, "dim"): 1605 self.dim = None 1606 1607 def forward(self, input: Tensor) -> Tensor: 1608 return F.softmin(input, self.dim, _stacklevel=5) 1609 1610 def extra_repr(self): 1611 return f"dim={self.dim}" 1612 1613 1614class Softmax(Module): 1615 r"""Applies the Softmax function to an n-dimensional input Tensor. 1616 1617 Rescales them so that the elements of the n-dimensional output Tensor 1618 lie in the range [0,1] and sum to 1. 1619 1620 Softmax is defined as: 1621 1622 .. math:: 1623 \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} 1624 1625 When the input Tensor is a sparse tensor then the unspecified 1626 values are treated as ``-inf``. 1627 1628 Shape: 1629 - Input: :math:`(*)` where `*` means, any number of additional 1630 dimensions 1631 - Output: :math:`(*)`, same shape as the input 1632 1633 Returns: 1634 a Tensor of the same dimension and shape as the input with 1635 values in the range [0, 1] 1636 1637 Args: 1638 dim (int): A dimension along which Softmax will be computed (so every slice 1639 along dim will sum to 1). 1640 1641 .. note:: 1642 This module doesn't work directly with NLLLoss, 1643 which expects the Log to be computed between the Softmax and itself. 1644 Use `LogSoftmax` instead (it's faster and has better numerical properties). 1645 1646 Examples:: 1647 1648 >>> m = nn.Softmax(dim=1) 1649 >>> input = torch.randn(2, 3) 1650 >>> output = m(input) 1651 1652 """ 1653 1654 __constants__ = ["dim"] 1655 dim: Optional[int] 1656 1657 def __init__(self, dim: Optional[int] = None) -> None: 1658 super().__init__() 1659 self.dim = dim 1660 1661 def __setstate__(self, state): 1662 super().__setstate__(state) 1663 if not hasattr(self, "dim"): 1664 self.dim = None 1665 1666 def forward(self, input: Tensor) -> Tensor: 1667 return F.softmax(input, self.dim, _stacklevel=5) 1668 1669 def extra_repr(self) -> str: 1670 return f"dim={self.dim}" 1671 1672 1673class Softmax2d(Module): 1674 r"""Applies SoftMax over features to each spatial location. 1675 1676 When given an image of ``Channels x Height x Width``, it will 1677 apply `Softmax` to each location :math:`(Channels, h_i, w_j)` 1678 1679 Shape: 1680 - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. 1681 - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) 1682 1683 Returns: 1684 a Tensor of the same dimension and shape as the input with 1685 values in the range [0, 1] 1686 1687 Examples:: 1688 1689 >>> m = nn.Softmax2d() 1690 >>> # you softmax over the 2nd dimension 1691 >>> input = torch.randn(2, 3, 12, 13) 1692 >>> output = m(input) 1693 """ 1694 1695 def forward(self, input: Tensor) -> Tensor: 1696 if input.dim() not in (3, 4): 1697 raise ValueError( 1698 f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead" 1699 ) 1700 return F.softmax(input, -3, _stacklevel=5) 1701 1702 1703class LogSoftmax(Module): 1704 r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor. 1705 1706 The LogSoftmax formulation can be simplified as: 1707 1708 .. math:: 1709 \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) 1710 1711 Shape: 1712 - Input: :math:`(*)` where `*` means, any number of additional 1713 dimensions 1714 - Output: :math:`(*)`, same shape as the input 1715 1716 Args: 1717 dim (int): A dimension along which LogSoftmax will be computed. 1718 1719 Returns: 1720 a Tensor of the same dimension and shape as the input with 1721 values in the range [-inf, 0) 1722 1723 Examples:: 1724 1725 >>> m = nn.LogSoftmax(dim=1) 1726 >>> input = torch.randn(2, 3) 1727 >>> output = m(input) 1728 """ 1729 1730 __constants__ = ["dim"] 1731 dim: Optional[int] 1732 1733 def __init__(self, dim: Optional[int] = None) -> None: 1734 super().__init__() 1735 self.dim = dim 1736 1737 def __setstate__(self, state): 1738 super().__setstate__(state) 1739 if not hasattr(self, "dim"): 1740 self.dim = None 1741 1742 def forward(self, input: Tensor) -> Tensor: 1743 return F.log_softmax(input, self.dim, _stacklevel=5) 1744 1745 def extra_repr(self): 1746 return f"dim={self.dim}" 1747