1# mypy: allow-untyped-defs
2import warnings
3from typing import Optional, Tuple
4
5import torch
6import torch.nn.functional as F
7from torch import Tensor
8from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9from torch.nn.parameter import Parameter
10
11from .linear import NonDynamicallyQuantizableLinear
12from .module import Module
13
14
15__all__ = [
16    "Threshold",
17    "ReLU",
18    "RReLU",
19    "Hardtanh",
20    "ReLU6",
21    "Sigmoid",
22    "Hardsigmoid",
23    "Tanh",
24    "SiLU",
25    "Mish",
26    "Hardswish",
27    "ELU",
28    "CELU",
29    "SELU",
30    "GLU",
31    "GELU",
32    "Hardshrink",
33    "LeakyReLU",
34    "LogSigmoid",
35    "Softplus",
36    "Softshrink",
37    "MultiheadAttention",
38    "PReLU",
39    "Softsign",
40    "Tanhshrink",
41    "Softmin",
42    "Softmax",
43    "Softmax2d",
44    "LogSoftmax",
45]
46
47
48class Threshold(Module):
49    r"""Thresholds each element of the input Tensor.
50
51    Threshold is defined as:
52
53    .. math::
54        y =
55        \begin{cases}
56        x, &\text{ if } x > \text{threshold} \\
57        \text{value}, &\text{ otherwise }
58        \end{cases}
59
60    Args:
61        threshold: The value to threshold at
62        value: The value to replace with
63        inplace: can optionally do the operation in-place. Default: ``False``
64
65    Shape:
66        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
67        - Output: :math:`(*)`, same shape as the input.
68
69    Examples::
70
71        >>> m = nn.Threshold(0.1, 20)
72        >>> input = torch.randn(2)
73        >>> output = m(input)
74    """
75
76    __constants__ = ["threshold", "value", "inplace"]
77
78    threshold: float
79    value: float
80    inplace: bool
81
82    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
83        super().__init__()
84        self.threshold = threshold
85        self.value = value
86        self.inplace = inplace
87        # TODO: check in THNN (if inplace == True, then assert value <= threshold)
88
89    def forward(self, input: Tensor) -> Tensor:
90        return F.threshold(input, self.threshold, self.value, self.inplace)
91
92    def extra_repr(self):
93        inplace_str = ", inplace=True" if self.inplace else ""
94        return f"threshold={self.threshold}, value={self.value}{inplace_str}"
95
96
97class ReLU(Module):
98    r"""Applies the rectified linear unit function element-wise.
99
100    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
101
102    Args:
103        inplace: can optionally do the operation in-place. Default: ``False``
104
105    Shape:
106        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
107        - Output: :math:`(*)`, same shape as the input.
108
109    .. image:: ../scripts/activation_images/ReLU.png
110
111    Examples::
112
113        >>> m = nn.ReLU()
114        >>> input = torch.randn(2)
115        >>> output = m(input)
116
117
118    An implementation of CReLU - https://arxiv.org/abs/1603.05201
119
120        >>> m = nn.ReLU()
121        >>> input = torch.randn(2).unsqueeze(0)
122        >>> output = torch.cat((m(input), m(-input)))
123    """
124
125    __constants__ = ["inplace"]
126    inplace: bool
127
128    def __init__(self, inplace: bool = False):
129        super().__init__()
130        self.inplace = inplace
131
132    def forward(self, input: Tensor) -> Tensor:
133        return F.relu(input, inplace=self.inplace)
134
135    def extra_repr(self) -> str:
136        inplace_str = "inplace=True" if self.inplace else ""
137        return inplace_str
138
139
140class RReLU(Module):
141    r"""Applies the randomized leaky rectified linear unit function, element-wise.
142
143    Method described in the paper:
144    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
145
146    The function is defined as:
147
148    .. math::
149        \text{RReLU}(x) =
150        \begin{cases}
151            x & \text{if } x \geq 0 \\
152            ax & \text{ otherwise }
153        \end{cases}
154
155    where :math:`a` is randomly sampled from the uniform distribution
156    :math:`\mathcal{U}(\text{lower}, \text{upper})` during training, while during
157    evaluation :math:`a` is fixed at :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
158
159    Args:
160        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
161        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
162        inplace: can optionally do the operation in-place. Default: ``False``
163
164    Shape:
165        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
166        - Output: :math:`(*)`, same shape as the input.
167
168    .. image:: ../scripts/activation_images/RReLU.png
169
170    Examples::
171
172        >>> m = nn.RReLU(0.1, 0.3)
173        >>> input = torch.randn(2)
174        >>> output = m(input)
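        >>> # in eval mode, a is fixed at (lower + upper) / 2, as described above
        >>> m = m.eval()
        >>> output = m(input)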
175
176    """
177
178    __constants__ = ["lower", "upper", "inplace"]
179
180    lower: float
181    upper: float
182    inplace: bool
183
184    def __init__(
185        self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False
186    ):
187        super().__init__()
188        self.lower = lower
189        self.upper = upper
190        self.inplace = inplace
191
192    def forward(self, input: Tensor) -> Tensor:
193        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
194
195    def extra_repr(self):
196        inplace_str = ", inplace=True" if self.inplace else ""
197        return f"lower={self.lower}, upper={self.upper}{inplace_str}"
198
199
200class Hardtanh(Module):
201    r"""Applies the HardTanh function element-wise.
202
203    HardTanh is defined as:
204
205    .. math::
206        \text{HardTanh}(x) = \begin{cases}
207            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
208            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
209            x & \text{ otherwise } \\
210        \end{cases}
211
212    Args:
213        min_val: minimum value of the linear region range. Default: -1
214        max_val: maximum value of the linear region range. Default: 1
215        inplace: can optionally do the operation in-place. Default: ``False``
216
217    Keyword arguments :attr:`min_value` and :attr:`max_value`
218    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
219
220    Shape:
221        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
222        - Output: :math:`(*)`, same shape as the input.
223
224    .. image:: ../scripts/activation_images/Hardtanh.png
225
226    Examples::
227
228        >>> m = nn.Hardtanh(-2, 2)
229        >>> input = torch.randn(2)
230        >>> output = m(input)
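        >>> # the deprecated keyword arguments still work, but emit a FutureWarning
        >>> m = nn.Hardtanh(min_value=-2.0, max_value=2.0)
        >>> output = m(input)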
231    """
232
233    __constants__ = ["min_val", "max_val", "inplace"]
234
235    min_val: float
236    max_val: float
237    inplace: bool
238
239    def __init__(
240        self,
241        min_val: float = -1.0,
242        max_val: float = 1.0,
243        inplace: bool = False,
244        min_value: Optional[float] = None,
245        max_value: Optional[float] = None,
246    ) -> None:
247        super().__init__()
248        if min_value is not None:
249            warnings.warn(
250                "keyword argument `min_value` is deprecated and has been renamed to `min_val`",
251                FutureWarning,
252                stacklevel=2,
253            )
254            min_val = min_value
255        if max_value is not None:
256            warnings.warn(
257                "keyword argument `max_value` is deprecated and has been renamed to `max_val`",
258                FutureWarning,
259                stacklevel=2,
260            )
261            max_val = max_value
262
263        self.min_val = min_val
264        self.max_val = max_val
265        self.inplace = inplace
266        assert self.max_val > self.min_val
267
268    def forward(self, input: Tensor) -> Tensor:
269        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
270
271    def extra_repr(self) -> str:
272        inplace_str = ", inplace=True" if self.inplace else ""
273        return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}"
274
275
276class ReLU6(Hardtanh):
277    r"""Applies the ReLU6 function element-wise.
278
279    .. math::
280        \text{ReLU6}(x) = \min(\max(0,x), 6)
281
282    Args:
283        inplace: can optionally do the operation in-place. Default: ``False``
284
285    Shape:
286        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
287        - Output: :math:`(*)`, same shape as the input.
288
289    .. image:: ../scripts/activation_images/ReLU6.png
290
291    Examples::
292
293        >>> m = nn.ReLU6()
294        >>> input = torch.randn(2)
295        >>> output = m(input)
296    """
297
298    def __init__(self, inplace: bool = False):
299        super().__init__(0.0, 6.0, inplace)
300
301    def extra_repr(self) -> str:
302        inplace_str = "inplace=True" if self.inplace else ""
303        return inplace_str
304
305
306class Sigmoid(Module):
307    r"""Applies the Sigmoid function element-wise.
308
309    .. math::
310        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
311
312
313    Shape:
314        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
315        - Output: :math:`(*)`, same shape as the input.
316
317    .. image:: ../scripts/activation_images/Sigmoid.png
318
319    Examples::
320
321        >>> m = nn.Sigmoid()
322        >>> input = torch.randn(2)
323        >>> output = m(input)
324    """
325
326    def forward(self, input: Tensor) -> Tensor:
327        return torch.sigmoid(input)
328
329
330class Hardsigmoid(Module):
331    r"""Applies the Hardsigmoid function element-wise.
332
333    Hardsigmoid is defined as:
334
335    .. math::
336        \text{Hardsigmoid}(x) = \begin{cases}
337            0 & \text{if~} x \le -3, \\
338            1 & \text{if~} x \ge +3, \\
339            x / 6 + 1 / 2 & \text{otherwise}
340        \end{cases}
341
342    Args:
343        inplace: can optionally do the operation in-place. Default: ``False``
344
345    Shape:
346        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
347        - Output: :math:`(*)`, same shape as the input.
348
349    .. image:: ../scripts/activation_images/Hardsigmoid.png
350
351    Examples::
352
353        >>> m = nn.Hardsigmoid()
354        >>> input = torch.randn(2)
355        >>> output = m(input)
356    """
357
358    __constants__ = ["inplace"]
359
360    inplace: bool
361
362    def __init__(self, inplace: bool = False) -> None:
363        super().__init__()
364        self.inplace = inplace
365
366    def forward(self, input: Tensor) -> Tensor:
367        return F.hardsigmoid(input, self.inplace)
368
369
370class Tanh(Module):
371    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
372
373    Tanh is defined as:
374
375    .. math::
376        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
377
378    Shape:
379        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
380        - Output: :math:`(*)`, same shape as the input.
381
382    .. image:: ../scripts/activation_images/Tanh.png
383
384    Examples::
385
386        >>> m = nn.Tanh()
387        >>> input = torch.randn(2)
388        >>> output = m(input)
389    """
390
391    def forward(self, input: Tensor) -> Tensor:
392        return torch.tanh(input)
393
394
395class SiLU(Module):
396    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
397
398    The SiLU function is also known as the swish function.
399
400    .. math::
401        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
402
403    .. note::
404        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
405        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
406        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
407        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
408        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
409        where the SiLU was experimented with later.
410
411    Shape:
412        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
413        - Output: :math:`(*)`, same shape as the input.
414
415    .. image:: ../scripts/activation_images/SiLU.png
416
417    Examples::
418
419        >>> m = nn.SiLU()
420        >>> input = torch.randn(2)
421        >>> output = m(input)
422    """
423
424    __constants__ = ["inplace"]
425    inplace: bool
426
427    def __init__(self, inplace: bool = False):
428        super().__init__()
429        self.inplace = inplace
430
431    def forward(self, input: Tensor) -> Tensor:
432        return F.silu(input, inplace=self.inplace)
433
434    def extra_repr(self) -> str:
435        inplace_str = "inplace=True" if self.inplace else ""
436        return inplace_str
437
438
439class Mish(Module):
440    r"""Applies the Mish function, element-wise.
441
442    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
443
444    .. math::
445        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
446
447    .. note::
448        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
449
450    Shape:
451        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
452        - Output: :math:`(*)`, same shape as the input.
453
454    .. image:: ../scripts/activation_images/Mish.png
455
456    Examples::
457
458        >>> m = nn.Mish()
459        >>> input = torch.randn(2)
460        >>> output = m(input)
461    """
462
463    __constants__ = ["inplace"]
464    inplace: bool
465
466    def __init__(self, inplace: bool = False):
467        super().__init__()
468        self.inplace = inplace
469
470    def forward(self, input: Tensor) -> Tensor:
471        return F.mish(input, inplace=self.inplace)
472
473    def extra_repr(self) -> str:
474        inplace_str = "inplace=True" if self.inplace else ""
475        return inplace_str
476
477
478class Hardswish(Module):
479    r"""Applies the Hardswish function, element-wise.
480
481    Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
482
483    Hardswish is defined as:
484
485    .. math::
486        \text{Hardswish}(x) = \begin{cases}
487            0 & \text{if~} x \le -3, \\
488            x & \text{if~} x \ge +3, \\
489            x \cdot (x + 3) /6 & \text{otherwise}
490        \end{cases}
491
492    Args:
493        inplace: can optionally do the operation in-place. Default: ``False``
494
495    Shape:
496        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
497        - Output: :math:`(*)`, same shape as the input.
498
499    .. image:: ../scripts/activation_images/Hardswish.png
500
501    Examples::
502
503        >>> m = nn.Hardswish()
504        >>> input = torch.randn(2)
505        >>> output = m(input)
506    """
507
508    __constants__ = ["inplace"]
509
510    inplace: bool
511
512    def __init__(self, inplace: bool = False) -> None:
513        super().__init__()
514        self.inplace = inplace
515
516    def forward(self, input: Tensor) -> Tensor:
517        return F.hardswish(input, self.inplace)
518
519
520class ELU(Module):
521    r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
522
523    Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
524    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
525
526    ELU is defined as:
527
528    .. math::
529        \text{ELU}(x) = \begin{cases}
530        x, & \text{ if } x > 0\\
531        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
532        \end{cases}
533
534    Args:
535        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
536        inplace: can optionally do the operation in-place. Default: ``False``
537
538    Shape:
539        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
540        - Output: :math:`(*)`, same shape as the input.
541
542    .. image:: ../scripts/activation_images/ELU.png
543
544    Examples::
545
546        >>> m = nn.ELU()
547        >>> input = torch.randn(2)
548        >>> output = m(input)
549    """
550
551    __constants__ = ["alpha", "inplace"]
552    alpha: float
553    inplace: bool
554
555    def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
556        super().__init__()
557        self.alpha = alpha
558        self.inplace = inplace
559
560    def forward(self, input: Tensor) -> Tensor:
561        return F.elu(input, self.alpha, self.inplace)
562
563    def extra_repr(self) -> str:
564        inplace_str = ", inplace=True" if self.inplace else ""
565        return f"alpha={self.alpha}{inplace_str}"
566
567
568class CELU(Module):
569    r"""Applies the CELU function element-wise.
570
571    .. math::
572        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
573
574    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
575
576    Args:
577        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
578        inplace: can optionally do the operation in-place. Default: ``False``
579
580    Shape:
581        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
582        - Output: :math:`(*)`, same shape as the input.
583
584    .. image:: ../scripts/activation_images/CELU.png
585
586    Examples::
587
588        >>> m = nn.CELU()
589        >>> input = torch.randn(2)
590        >>> output = m(input)
591
592    .. _`Continuously Differentiable Exponential Linear Units`:
593        https://arxiv.org/abs/1704.07483
594    """
595
596    __constants__ = ["alpha", "inplace"]
597    alpha: float
598    inplace: bool
599
600    def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
601        super().__init__()
602        self.alpha = alpha
603        self.inplace = inplace
604
605    def forward(self, input: Tensor) -> Tensor:
606        return F.celu(input, self.alpha, self.inplace)
607
608    def extra_repr(self) -> str:
609        inplace_str = ", inplace=True" if self.inplace else ""
610        return f"alpha={self.alpha}{inplace_str}"
611
612
613class SELU(Module):
614    r"""Applies the SELU function element-wise.
615
616    .. math::
617        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
618
619    with :math:`\alpha = 1.6732632423543772848170429916717` and
620    :math:`\text{scale} = 1.0507009873554804934193349852946`.
621
622    .. warning::
623        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
624        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
625        in order to get `Self-Normalizing Neural Networks`_.
626        See :func:`torch.nn.init.calculate_gain` for more information.
627
628    More details can be found in the paper `Self-Normalizing Neural Networks`_ .
629
630    Args:
631        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
632
633    Shape:
634        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
635        - Output: :math:`(*)`, same shape as the input.
636
637    .. image:: ../scripts/activation_images/SELU.png
638
639    Examples::
640
641        >>> m = nn.SELU()
642        >>> input = torch.randn(2)
643        >>> output = m(input)
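        >>> # per the warning above, pair SELU with kaiming init using nonlinearity='linear'
        >>> w = torch.empty(3, 5)
        >>> _ = torch.nn.init.kaiming_normal_(w, nonlinearity='linear')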
644
645    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
646    """
647
648    __constants__ = ["inplace"]
649    inplace: bool
650
651    def __init__(self, inplace: bool = False) -> None:
652        super().__init__()
653        self.inplace = inplace
654
655    def forward(self, input: Tensor) -> Tensor:
656        return F.selu(input, self.inplace)
657
658    def extra_repr(self) -> str:
659        inplace_str = "inplace=True" if self.inplace else ""
660        return inplace_str
661
662
663class GLU(Module):
664    r"""Applies the gated linear unit function.
665
666    :math:`\text{GLU}(a, b) = a \otimes \sigma(b)` where :math:`a` is the first half
667    of the input matrices and :math:`b` is the second half.
668
669    Args:
670        dim (int): the dimension on which to split the input. Default: -1
671
672    Shape:
673        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
674          dimensions
675        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
676
677    Examples::
678
679        >>> m = nn.GLU()
680        >>> input = torch.randn(4, 2)
681        >>> output = m(input)
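        >>> # the size of the split dimension (here the default, dim=-1) is halved
        >>> output.shape
        torch.Size([4, 1])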
682    """
683
684    __constants__ = ["dim"]
685    dim: int
686
687    def __init__(self, dim: int = -1) -> None:
688        super().__init__()
689        self.dim = dim
690
691    def forward(self, input: Tensor) -> Tensor:
692        return F.glu(input, self.dim)
693
694    def extra_repr(self) -> str:
695        return f"dim={self.dim}"
696
697
698class GELU(Module):
699    r"""Applies the Gaussian Error Linear Units function.
700
701    .. math:: \text{GELU}(x) = x * \Phi(x)
702
703    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
704
705    When the approximate argument is 'tanh', GELU is estimated with:
706
707    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
708
709    Args:
710        approximate (str, optional): the gelu approximation algorithm to use:
711            ``'none'`` | ``'tanh'``. Default: ``'none'``
712
713    Shape:
714        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
715        - Output: :math:`(*)`, same shape as the input.
716
717    .. image:: ../scripts/activation_images/GELU.png
718
719    Examples::
720
721        >>> m = nn.GELU()
722        >>> input = torch.randn(2)
723        >>> output = m(input)
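        >>> # using the tanh approximation described above
        >>> m = nn.GELU(approximate='tanh')
        >>> output = m(input)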
724    """
725
726    __constants__ = ["approximate"]
727    approximate: str
728
729    def __init__(self, approximate: str = "none") -> None:
730        super().__init__()
731        self.approximate = approximate
732
733    def forward(self, input: Tensor) -> Tensor:
734        return F.gelu(input, approximate=self.approximate)
735
736    def extra_repr(self) -> str:
737        return f"approximate={repr(self.approximate)}"
738
739
740class Hardshrink(Module):
741    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
742
743    Hardshrink is defined as:
744
745    .. math::
746        \text{HardShrink}(x) =
747        \begin{cases}
748        x, & \text{ if } x > \lambda \\
749        x, & \text{ if } x < -\lambda \\
750        0, & \text{ otherwise }
751        \end{cases}
752
753    Args:
754        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
755
756    Shape:
757        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
758        - Output: :math:`(*)`, same shape as the input.
759
760    .. image:: ../scripts/activation_images/Hardshrink.png
761
762    Examples::
763
764        >>> m = nn.Hardshrink()
765        >>> input = torch.randn(2)
766        >>> output = m(input)
767    """
768
769    __constants__ = ["lambd"]
770    lambd: float
771
772    def __init__(self, lambd: float = 0.5) -> None:
773        super().__init__()
774        self.lambd = lambd
775
776    def forward(self, input: Tensor) -> Tensor:
777        return F.hardshrink(input, self.lambd)
778
779    def extra_repr(self) -> str:
780        return f"{self.lambd}"
781
782
783class LeakyReLU(Module):
784    r"""Applies the LeakyReLU function element-wise.
785
786    .. math::
787        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
788
789
790    or
791
792    .. math::
793        \text{LeakyReLU}(x) =
794        \begin{cases}
795        x, & \text{ if } x \geq 0 \\
796        \text{negative\_slope} \times x, & \text{ otherwise }
797        \end{cases}
798
799    Args:
800        negative_slope: Controls the angle of the negative slope (which is used for
801          negative input values). Default: 1e-2
802        inplace: can optionally do the operation in-place. Default: ``False``
803
804    Shape:
805        - Input: :math:`(*)` where `*` means any number of additional
806          dimensions
807        - Output: :math:`(*)`, same shape as the input
808
809    .. image:: ../scripts/activation_images/LeakyReLU.png
810
811    Examples::
812
813        >>> m = nn.LeakyReLU(0.1)
814        >>> input = torch.randn(2)
815        >>> output = m(input)
816    """
817
818    __constants__ = ["inplace", "negative_slope"]
819    inplace: bool
820    negative_slope: float
821
822    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
823        super().__init__()
824        self.negative_slope = negative_slope
825        self.inplace = inplace
826
827    def forward(self, input: Tensor) -> Tensor:
828        return F.leaky_relu(input, self.negative_slope, self.inplace)
829
830    def extra_repr(self) -> str:
831        inplace_str = ", inplace=True" if self.inplace else ""
832        return f"negative_slope={self.negative_slope}{inplace_str}"
833
834
835class LogSigmoid(Module):
836    r"""Applies the Logsigmoid function element-wise.
837
838    .. math::
839        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
840
841    Shape:
842        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
843        - Output: :math:`(*)`, same shape as the input.
844
845    .. image:: ../scripts/activation_images/LogSigmoid.png
846
847    Examples::
848
849        >>> m = nn.LogSigmoid()
850        >>> input = torch.randn(2)
851        >>> output = m(input)
852    """
853
854    def forward(self, input: Tensor) -> Tensor:
855        return F.logsigmoid(input)
856
857
858class Softplus(Module):
859    r"""Applies the Softplus function element-wise.
860
861    .. math::
862        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
863
864    SoftPlus is a smooth approximation to the ReLU function and can be used
865    to constrain the output of a machine to always be positive.
866
867    For numerical stability the implementation reverts to the linear function
868    when :math:`\text{input} \times \beta > \text{threshold}`.
869
870    Args:
871        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
872        threshold: values above this revert to a linear function. Default: 20
873
874    Shape:
875        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
876        - Output: :math:`(*)`, same shape as the input.
877
878    .. image:: ../scripts/activation_images/Softplus.png
879
880    Examples::
881
882        >>> m = nn.Softplus()
883        >>> input = torch.randn(2)
884        >>> output = m(input)
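        >>> # custom beta and threshold (see Args above)
        >>> m = nn.Softplus(beta=2.0, threshold=10.0)
        >>> output = m(input)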
885    """
886
887    __constants__ = ["beta", "threshold"]
888    beta: float
889    threshold: float
890
891    def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
892        super().__init__()
893        self.beta = beta
894        self.threshold = threshold
895
896    def forward(self, input: Tensor) -> Tensor:
897        return F.softplus(input, self.beta, self.threshold)
898
899    def extra_repr(self) -> str:
900        return f"beta={self.beta}, threshold={self.threshold}"
901
902
903class Softshrink(Module):
904    r"""Applies the soft shrinkage function element-wise.
905
906    .. math::
907        \text{SoftShrinkage}(x) =
908        \begin{cases}
909        x - \lambda, & \text{ if } x > \lambda \\
910        x + \lambda, & \text{ if } x < -\lambda \\
911        0, & \text{ otherwise }
912        \end{cases}
913
914    Args:
915        lambd: the :math:`\lambda` value (must be no less than zero) for the Softshrink formulation. Default: 0.5
916
917    Shape:
918        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
919        - Output: :math:`(*)`, same shape as the input.
920
921    .. image:: ../scripts/activation_images/Softshrink.png
922
923    Examples::
924
925        >>> m = nn.Softshrink()
926        >>> input = torch.randn(2)
927        >>> output = m(input)
928    """
929
930    __constants__ = ["lambd"]
931    lambd: float
932
933    def __init__(self, lambd: float = 0.5) -> None:
934        super().__init__()
935        self.lambd = lambd
936
937    def forward(self, input: Tensor) -> Tensor:
938        return F.softshrink(input, self.lambd)
939
940    def extra_repr(self) -> str:
941        return str(self.lambd)
942
943
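# Helpers for the MultiheadAttention fastpath check in forward() below.
# _check_arg_device: a (possibly None) tensor argument is eligible only if it lives
# on CPU, CUDA, or the registered privateuse1 backend.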
944def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
945    if x is not None:
946        return x.device.type in [
947            "cpu",
948            "cuda",
949            torch.utils.backend_registration._privateuse1_backend_name,
950        ]
951    return True
952
953
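# _arg_requires_grad: True only for tensors that require grad (None counts as False).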
954def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
955    if x is not None:
956        return x.requires_grad
957    return False
958
959
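# _is_make_fx_tracing: detects an active make_fx trace by looking for a
# ProxyTorchDispatchMode on the current dispatch mode stack (always False under TorchScript).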
960def _is_make_fx_tracing():
961    if not torch.jit.is_scripting():
962        torch_dispatch_mode_stack = (
963            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
964        )
965        return any(
966            type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode
967            for x in torch_dispatch_mode_stack
968        )
969    else:
970        return False
971
972
973class MultiheadAttention(Module):
974    r"""Allows the model to jointly attend to information from different representation subspaces.
975
976    Method described in the paper:
977    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
978
979    Multi-Head Attention is defined as:
980
981    .. math::
982        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
983
984    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
985
986    ``nn.MultiheadAttention`` will use the optimized implementations of
987    ``scaled_dot_product_attention()`` when possible.
988
989    In addition to support for the new ``scaled_dot_product_attention()``
990    function, MHA will use the fastpath implementation (with support for
991    Nested Tensors) to speed up inference, if and only if:
992
993    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
994    - inputs are batched (3D) with ``batch_first==True``
995    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
996    - training is disabled (using ``.eval()``)
997    - ``add_bias_kv`` is ``False``
998    - ``add_zero_attn`` is ``False``
999    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
1000    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
1001      nor ``attn_mask`` is passed
1002    - autocast is disabled
1003
1004    If the optimized inference fastpath implementation is in use, a
1005    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
1006    ``query``/``key``/``value`` to represent padding more efficiently than using a
1007    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
1008    will be returned, and an additional speedup proportional to the fraction of the input
1009    that is padding can be expected.
1010
1011    Args:
1012        embed_dim: Total dimension of the model.
1013        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
1014            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
1015        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
1016        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
1017        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
1018        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
1019            Default: ``False``.
1020        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
1021        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
1022        batch_first: If ``True``, then the input and output tensors are provided
1023            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
1024
1025    Examples::
1026
1027        >>> # xdoctest: +SKIP
1028        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
1029        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
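        >>> # with batch_first=True, inputs and outputs are (batch, seq, feature)
        >>> # instead of the default (seq, batch, feature)
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)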
1030
1031    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
1032         https://arxiv.org/abs/2205.14135
1033
1034    """
1035
1036    __constants__ = ["batch_first"]
1037    bias_k: Optional[torch.Tensor]
1038    bias_v: Optional[torch.Tensor]
1039
1040    def __init__(
1041        self,
1042        embed_dim,
1043        num_heads,
1044        dropout=0.0,
1045        bias=True,
1046        add_bias_kv=False,
1047        add_zero_attn=False,
1048        kdim=None,
1049        vdim=None,
1050        batch_first=False,
1051        device=None,
1052        dtype=None,
1053    ) -> None:
1054        if embed_dim <= 0 or num_heads <= 0:
1055            raise ValueError(
1056                f"embed_dim and num_heads must be greater than 0,"
1057                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
1058            )
1059        factory_kwargs = {"device": device, "dtype": dtype}
1060        super().__init__()
1061        self.embed_dim = embed_dim
1062        self.kdim = kdim if kdim is not None else embed_dim
1063        self.vdim = vdim if vdim is not None else embed_dim
1064        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
1065
1066        self.num_heads = num_heads
1067        self.dropout = dropout
1068        self.batch_first = batch_first
1069        self.head_dim = embed_dim // num_heads
1070        assert (
1071            self.head_dim * num_heads == self.embed_dim
1072        ), "embed_dim must be divisible by num_heads"
1073
1074        if not self._qkv_same_embed_dim:
1075            self.q_proj_weight = Parameter(
1076                torch.empty((embed_dim, embed_dim), **factory_kwargs)
1077            )
1078            self.k_proj_weight = Parameter(
1079                torch.empty((embed_dim, self.kdim), **factory_kwargs)
1080            )
1081            self.v_proj_weight = Parameter(
1082                torch.empty((embed_dim, self.vdim), **factory_kwargs)
1083            )
1084            self.register_parameter("in_proj_weight", None)
1085        else:
1086            self.in_proj_weight = Parameter(
1087                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
1088            )
1089            self.register_parameter("q_proj_weight", None)
1090            self.register_parameter("k_proj_weight", None)
1091            self.register_parameter("v_proj_weight", None)
1092
1093        if bias:
1094            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
1095        else:
1096            self.register_parameter("in_proj_bias", None)
1097        self.out_proj = NonDynamicallyQuantizableLinear(
1098            embed_dim, embed_dim, bias=bias, **factory_kwargs
1099        )
1100
1101        if add_bias_kv:
1102            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1103            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1104        else:
1105            self.bias_k = self.bias_v = None
1106
1107        self.add_zero_attn = add_zero_attn
1108
1109        self._reset_parameters()
1110
1111    def _reset_parameters(self):
1112        if self._qkv_same_embed_dim:
1113            xavier_uniform_(self.in_proj_weight)
1114        else:
1115            xavier_uniform_(self.q_proj_weight)
1116            xavier_uniform_(self.k_proj_weight)
1117            xavier_uniform_(self.v_proj_weight)
1118
1119        if self.in_proj_bias is not None:
1120            constant_(self.in_proj_bias, 0.0)
1121            constant_(self.out_proj.bias, 0.0)
1122        if self.bias_k is not None:
1123            xavier_normal_(self.bias_k)
1124        if self.bias_v is not None:
1125            xavier_normal_(self.bias_v)
1126
1127    def __setstate__(self, state):
1128        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
1129        if "_qkv_same_embed_dim" not in state:
1130            state["_qkv_same_embed_dim"] = True
1131
1132        super().__setstate__(state)
1133
1134    def forward(
1135        self,
1136        query: Tensor,
1137        key: Tensor,
1138        value: Tensor,
1139        key_padding_mask: Optional[Tensor] = None,
1140        need_weights: bool = True,
1141        attn_mask: Optional[Tensor] = None,
1142        average_attn_weights: bool = True,
1143        is_causal: bool = False,
1144    ) -> Tuple[Tensor, Optional[Tensor]]:
1145        r"""Compute attention outputs using query, key, and value embeddings.
1146
1147        Supports optional parameters for padding, masks and attention weights.
1148
1149        Args:
1150            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
1151                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
1152                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
1153                Queries are compared against key-value pairs to produce the output.
1154                See "Attention Is All You Need" for more details.
1155            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
1156                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
1157                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
1158                See "Attention Is All You Need" for more details.
1159            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
1160                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
1161                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
1162                See "Attention Is All You Need" for more details.
1163            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
1164                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
1165                Binary and float masks are supported.
1166                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
1167                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
1168            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_output``.
1169                Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
1170                and achieve the best performance for MHA.
1171                Default: ``True``.
1172            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
1173                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
1174                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
1175                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
1176                Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
1177                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
1178                the attention weight.
1179                If both attn_mask and key_padding_mask are supplied, their types should match.
1180            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
1181                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
1182                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
1183            is_causal: If specified, applies a causal mask as attention mask.
1184                Default: ``False``.
1185                Warning:
1186                ``is_causal`` provides a hint that ``attn_mask`` is the
1187                causal mask. Providing incorrect hints can result in
1188                causal mask. Providing incorrect hints can result in
1189                incorrect execution, as well as forward and backward
1190                compatibility issues.
1191        Outputs:
1192            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
1193              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
1194              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
1195              embedding dimension ``embed_dim``.
1196            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
1197              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
1198              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
1199              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
1200              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
1201
1202            .. note::
1203                `batch_first` argument is ignored for unbatched inputs.
1204        """  # noqa: B950
1205        why_not_fast_path = ""
1206        if (
1207            (attn_mask is not None and torch.is_floating_point(attn_mask))
1208            or (key_padding_mask is not None)
1209            and torch.is_floating_point(key_padding_mask)
1210        ):
1211            why_not_fast_path = "floating-point masks are not supported for fast path."
1212
1213        is_batched = query.dim() == 3
1214
1215        key_padding_mask = F._canonical_mask(
1216            mask=key_padding_mask,
1217            mask_name="key_padding_mask",
1218            other_type=F._none_or_dtype(attn_mask),
1219            other_name="attn_mask",
1220            target_type=query.dtype,
1221        )
1222
1223        attn_mask = F._canonical_mask(
1224            mask=attn_mask,
1225            mask_name="attn_mask",
1226            other_type=None,
1227            other_name="",
1228            target_type=query.dtype,
1229            check_other=False,
1230        )
1231
1232        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
1233
1234        if not is_fastpath_enabled:
1235            why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
1236        elif not is_batched:
1237            why_not_fast_path = (
1238                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
1239            )
1240        elif query is not key or key is not value:
1241            # When lifting this restriction, don't forget to either
1242            # enforce that the dtypes all match or test cases where
1243            # they don't!
1244            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
1245        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
1246            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
1247        elif self.in_proj_weight is None:
1248            why_not_fast_path = "in_proj_weight was None"
1249        elif query.dtype != self.in_proj_weight.dtype:
1250            # this case will fail anyway, but at least they'll get a useful error message.
1251            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
1252        elif self.training:
1253            why_not_fast_path = "training is enabled"
1254        elif (self.num_heads % 2) != 0:
1255            why_not_fast_path = "self.num_heads is not even"
1256        elif not self.batch_first:
1257            why_not_fast_path = "batch_first was not True"
1258        elif self.bias_k is not None:
1259            why_not_fast_path = "self.bias_k was not None"
1260        elif self.bias_v is not None:
1261            why_not_fast_path = "self.bias_v was not None"
1262        elif self.add_zero_attn:
1263            why_not_fast_path = "add_zero_attn was enabled"
1264        elif not self._qkv_same_embed_dim:
1265            why_not_fast_path = "_qkv_same_embed_dim was not True"
1266        elif query.is_nested and (
1267            key_padding_mask is not None or attn_mask is not None
1268        ):
1269            why_not_fast_path = ("supplying both src_key_padding_mask and src_mask at "
1270                                 "the same time is not supported with NestedTensor input")
1271        elif torch.is_autocast_enabled():
1272            why_not_fast_path = "autocast is enabled"
1273
1274        if not why_not_fast_path:
1275            tensor_args = (
1276                query,
1277                key,
1278                value,
1279                self.in_proj_weight,
1280                self.in_proj_bias,
1281                self.out_proj.weight,
1282                self.out_proj.bias,
1283            )
1284            # We have to use list comprehensions below because TorchScript does not support
1285            # generator expressions.
1286            if torch.overrides.has_torch_function(tensor_args):
1287                why_not_fast_path = "some Tensor argument has_torch_function"
1288            elif _is_make_fx_tracing():
1289                why_not_fast_path = "we are running make_fx tracing"
1290            elif not all(_check_arg_device(x) for x in tensor_args):
1291                why_not_fast_path = (
1292                    "some Tensor argument's device is neither one of "
1293                    f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
1294                )
1295            elif torch.is_grad_enabled() and any(
1296                _arg_requires_grad(x) for x in tensor_args
1297            ):
1298                why_not_fast_path = (
1299                    "grad is enabled and at least one of query or the "
1300                    "input/output projection weights or biases requires_grad"
1301                )
1302            if not why_not_fast_path:
1303                merged_mask, mask_type = self.merge_masks(
1304                    attn_mask, key_padding_mask, query
1305                )
1306
1307                if self.in_proj_bias is not None and self.in_proj_weight is not None:
1308                    return torch._native_multi_head_attention(
1309                        query,
1310                        key,
1311                        value,
1312                        self.embed_dim,
1313                        self.num_heads,
1314                        self.in_proj_weight,
1315                        self.in_proj_bias,
1316                        self.out_proj.weight,
1317                        self.out_proj.bias,
1318                        merged_mask,
1319                        need_weights,
1320                        average_attn_weights,
1321                        mask_type,
1322                    )
1323
1324        any_nested = query.is_nested or key.is_nested or value.is_nested
1325        assert not any_nested, (
1326            "MultiheadAttention does not support NestedTensor outside of its fast path. "
1327            + f"The fast path was not hit because {why_not_fast_path}"
1328        )
1329
1330        if self.batch_first and is_batched:
1331            # make sure that the transpose op does not affect the "is" property
1332            if key is value:
1333                if query is key:
1334                    query = key = value = query.transpose(1, 0)
1335                else:
1336                    query, key = (x.transpose(1, 0) for x in (query, key))
1337                    value = key
1338            else:
1339                query, key, value = (x.transpose(1, 0) for x in (query, key, value))
1340
1341        if not self._qkv_same_embed_dim:
1342            attn_output, attn_output_weights = F.multi_head_attention_forward(
1343                query,
1344                key,
1345                value,
1346                self.embed_dim,
1347                self.num_heads,
1348                self.in_proj_weight,
1349                self.in_proj_bias,
1350                self.bias_k,
1351                self.bias_v,
1352                self.add_zero_attn,
1353                self.dropout,
1354                self.out_proj.weight,
1355                self.out_proj.bias,
1356                training=self.training,
1357                key_padding_mask=key_padding_mask,
1358                need_weights=need_weights,
1359                attn_mask=attn_mask,
1360                use_separate_proj_weight=True,
1361                q_proj_weight=self.q_proj_weight,
1362                k_proj_weight=self.k_proj_weight,
1363                v_proj_weight=self.v_proj_weight,
1364                average_attn_weights=average_attn_weights,
1365                is_causal=is_causal,
1366            )
1367        else:
1368            attn_output, attn_output_weights = F.multi_head_attention_forward(
1369                query,
1370                key,
1371                value,
1372                self.embed_dim,
1373                self.num_heads,
1374                self.in_proj_weight,
1375                self.in_proj_bias,
1376                self.bias_k,
1377                self.bias_v,
1378                self.add_zero_attn,
1379                self.dropout,
1380                self.out_proj.weight,
1381                self.out_proj.bias,
1382                training=self.training,
1383                key_padding_mask=key_padding_mask,
1384                need_weights=need_weights,
1385                attn_mask=attn_mask,
1386                average_attn_weights=average_attn_weights,
1387                is_causal=is_causal,
1388            )
1389        if self.batch_first and is_batched:
1390            return attn_output.transpose(1, 0), attn_output_weights
1391        else:
1392            return attn_output, attn_output_weights
1393
1394    def merge_masks(
1395        self,
1396        attn_mask: Optional[Tensor],
1397        key_padding_mask: Optional[Tensor],
1398        query: Tensor,
1399    ) -> Tuple[Optional[Tensor], Optional[int]]:
1400        r"""Determine mask type and combine masks if necessary.
1401
1402        If only one mask is provided, that mask
1403        and the corresponding mask type will be returned. If both masks are provided, they will both be
1404        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``,
1405        and mask type 2 will be returned.
1406        Args:
1407            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
1408            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
1409            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
1410        Returns:
1411            merged_mask: merged mask
1412            mask_type: merged mask type (0, 1, or 2)
1413        """
1414        mask_type: Optional[int] = None
1415        merged_mask: Optional[Tensor] = None
1416
1417        if key_padding_mask is not None:
1418            mask_type = 1
1419            merged_mask = key_padding_mask
1420
1421        if attn_mask is not None:
1422            # In this branch query can't be a nested tensor, so it has a shape
1423            batch_size, seq_len, _ = query.shape
1424            mask_type = 2
1425
1426            # Always expands attn_mask to 4D
1427            if attn_mask.dim() == 3:
1428                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
1429            else:  # attn_mask.dim() == 2:
1430                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
1431                    batch_size, self.num_heads, -1, -1
1432                )
1433            merged_mask = attn_mask_expanded
1434
1435            if key_padding_mask is not None:
1436                key_padding_mask_expanded = key_padding_mask.view(
1437                    batch_size, 1, 1, seq_len
1438                ).expand(-1, self.num_heads, -1, -1)
1439                merged_mask = attn_mask_expanded + key_padding_mask_expanded
1440
1441        # no attn_mask and no key_padding_mask, returns None, None
1442        return merged_mask, mask_type
1443
1444
1445class PReLU(Module):
1446    r"""Applies the element-wise PReLU function.
1447
1448    .. math::
1449        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
1450
1451    or
1452
1453    .. math::
1454        \text{PReLU}(x) =
1455        \begin{cases}
1456        x, & \text{ if } x \ge 0 \\
1457        ax, & \text{ otherwise }
1458        \end{cases}
1459
1460    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
1461    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
1462    a separate :math:`a` is used for each input channel.
1463
1464
1465    .. note::
1466        For good performance, weight decay should not be used when learning :math:`a`.
1467
1468    .. note::
1469        The channel dim is the 2nd dim of the input. When the input has fewer than 2 dims,
1470        there is no channel dim and the number of channels is 1.
1471
1472    Args:
1473        num_parameters (int): number of :math:`a` to learn.
1474            Although it takes an int as input, only two values are legitimate:
1475            1, or the number of channels of the input. Default: 1
1476        init (float): the initial value of :math:`a`. Default: 0.25
1477
1478    Shape:
1479        - Input: :math:`(*)` where `*` means any number of additional
1480          dimensions.
1481        - Output: :math:`(*)`, same shape as the input.
1482
1483    Attributes:
1484        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
1485
1486    .. image:: ../scripts/activation_images/PReLU.png
1487
1488    Examples::
1489
1490        >>> m = nn.PReLU()
1491        >>> input = torch.randn(2)
1492        >>> output = m(input)
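        >>> # one learnable parameter per channel (the channel dim is dim 1, see the note above)
        >>> m = nn.PReLU(num_parameters=3)
        >>> input = torch.randn(2, 3, 4)
        >>> output = m(input)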
1493    """
1494
1495    __constants__ = ["num_parameters"]
1496    num_parameters: int
1497
1498    def __init__(
1499        self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None
1500    ) -> None:
1501        factory_kwargs = {"device": device, "dtype": dtype}
1502        self.num_parameters = num_parameters
1503        super().__init__()
1504        self.init = init
1505        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
1506        self.reset_parameters()
1507
1508    def reset_parameters(self):
1509        torch.nn.init.constant_(self.weight, self.init)
1510
1511    def forward(self, input: Tensor) -> Tensor:
1512        return F.prelu(input, self.weight)
1513
1514    def extra_repr(self) -> str:
1515        return f"num_parameters={self.num_parameters}"
1516
1517
1518class Softsign(Module):
1519    r"""Applies the element-wise Softsign function.
1520
1521    .. math::
1522        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
1523
1524    Shape:
1525        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1526        - Output: :math:`(*)`, same shape as the input.
1527
1528    .. image:: ../scripts/activation_images/Softsign.png
1529
1530    Examples::
1531
1532        >>> m = nn.Softsign()
1533        >>> input = torch.randn(2)
1534        >>> output = m(input)
1535    """
1536
1537    def forward(self, input: Tensor) -> Tensor:
1538        return F.softsign(input)
1539
1540
1541class Tanhshrink(Module):
1542    r"""Applies the element-wise Tanhshrink function.
1543
1544    .. math::
1545        \text{Tanhshrink}(x) = x - \tanh(x)
1546
1547    Shape:
1548        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1549        - Output: :math:`(*)`, same shape as the input.
1550
1551    .. image:: ../scripts/activation_images/Tanhshrink.png
1552
1553    Examples::
1554
1555        >>> m = nn.Tanhshrink()
1556        >>> input = torch.randn(2)
1557        >>> output = m(input)
1558    """
1559
1560    def forward(self, input: Tensor) -> Tensor:
1561        return F.tanhshrink(input)
1562
1563
1564class Softmin(Module):
1565    r"""Applies the Softmin function to an n-dimensional input Tensor.
1566
1567    Rescales them so that the elements of the n-dimensional output Tensor
1568    Rescales the input so that the elements of the n-dimensional output Tensor
1569
1570    Softmin is defined as:
1571
1572    .. math::
1573        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
1574
1575    Shape:
1576        - Input: :math:`(*)` where `*` means any number of additional
1577          dimensions
1578        - Output: :math:`(*)`, same shape as the input
1579
1580    Args:
1581        dim (int): A dimension along which Softmin will be computed (so every slice
1582            along dim will sum to 1).
1583
1584    Returns:
1585        a Tensor of the same dimension and shape as the input, with
1586        values in the range [0, 1]
1587
1588    Examples::
1589
1590        >>> m = nn.Softmin(dim=1)
1591        >>> input = torch.randn(2, 3)
1592        >>> output = m(input)
1593    """
1594
1595    __constants__ = ["dim"]
1596    dim: Optional[int]
1597
1598    def __init__(self, dim: Optional[int] = None) -> None:
1599        super().__init__()
1600        self.dim = dim
1601
1602    def __setstate__(self, state):
1603        super().__setstate__(state)
1604        if not hasattr(self, "dim"):
1605            self.dim = None
1606
1607    def forward(self, input: Tensor) -> Tensor:
1608        return F.softmin(input, self.dim, _stacklevel=5)
1609
1610    def extra_repr(self):
1611        return f"dim={self.dim}"
1612
1613
1614class Softmax(Module):
1615    r"""Applies the Softmax function to an n-dimensional input Tensor.
1616
1617    Rescales the input so that the elements of the n-dimensional output Tensor
1618    lie in the range [0,1] and sum to 1.
1619
1620    Softmax is defined as:
1621
1622    .. math::
1623        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
1624
1625    When the input Tensor is a sparse tensor then the unspecified
1626    values are treated as ``-inf``.
1627
1628    Shape:
1629        - Input: :math:`(*)` where `*` means any number of additional
1630          dimensions
1631        - Output: :math:`(*)`, same shape as the input
1632
1633    Returns:
1634        a Tensor of the same dimension and shape as the input with
1635        values in the range [0, 1]
1636
1637    Args:
1638        dim (int): A dimension along which Softmax will be computed (so every slice
1639            along dim will sum to 1).
1640
1641    .. note::
1642        This module doesn't work directly with NLLLoss,
1643        which expects the log to be applied between the Softmax and itself.
1644        Use `LogSoftmax` instead (it's faster and has better numerical properties).
1645
1646    Examples::
1647
1648        >>> m = nn.Softmax(dim=1)
1649        >>> input = torch.randn(2, 3)
1650        >>> output = m(input)
1651
1652    """
1653
1654    __constants__ = ["dim"]
1655    dim: Optional[int]
1656
1657    def __init__(self, dim: Optional[int] = None) -> None:
1658        super().__init__()
1659        self.dim = dim
1660
1661    def __setstate__(self, state):
1662        super().__setstate__(state)
1663        if not hasattr(self, "dim"):
1664            self.dim = None
1665
1666    def forward(self, input: Tensor) -> Tensor:
1667        return F.softmax(input, self.dim, _stacklevel=5)
1668
1669    def extra_repr(self) -> str:
1670        return f"dim={self.dim}"
1671
1672
1673class Softmax2d(Module):
1674    r"""Applies SoftMax over features to each spatial location.
1675
1676    When given an image of ``Channels x Height x Width``, it will
1677    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1678
1679    Shape:
1680        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1681        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1682
1683    Returns:
1684        a Tensor of the same dimension and shape as the input with
1685        values in the range [0, 1]
1686
1687    Examples::
1688
1689        >>> m = nn.Softmax2d()
1690        >>> # softmax is applied over the 2nd (channel) dimension
1691        >>> input = torch.randn(2, 3, 12, 13)
1692        >>> output = m(input)
1693    """
1694
1695    def forward(self, input: Tensor) -> Tensor:
1696        if input.dim() not in (3, 4):
1697            raise ValueError(
1698                f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
1699            )
1700        return F.softmax(input, -3, _stacklevel=5)
1701
1702
1703class LogSoftmax(Module):
1704    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
1705
1706    The LogSoftmax formulation can be simplified as:
1707
1708    .. math::
1709        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1710
1711    Shape:
1712        - Input: :math:`(*)` where `*` means any number of additional
1713          dimensions
1714        - Output: :math:`(*)`, same shape as the input
1715
1716    Args:
1717        dim (int): A dimension along which LogSoftmax will be computed.
1718
1719    Returns:
1720        a Tensor of the same dimension and shape as the input with
1721        values in the range [-inf, 0)
1722
1723    Examples::
1724
1725        >>> m = nn.LogSoftmax(dim=1)
1726        >>> input = torch.randn(2, 3)
1727        >>> output = m(input)
1728    """
1729
1730    __constants__ = ["dim"]
1731    dim: Optional[int]
1732
1733    def __init__(self, dim: Optional[int] = None) -> None:
1734        super().__init__()
1735        self.dim = dim
1736
1737    def __setstate__(self, state):
1738        super().__setstate__(state)
1739        if not hasattr(self, "dim"):
1740            self.dim = None
1741
1742    def forward(self, input: Tensor) -> Tensor:
1743        return F.log_softmax(input, self.dim, _stacklevel=5)
1744
1745    def extra_repr(self):
1746        return f"dim={self.dim}"
1747