1# mypy: allow-untyped-defs
2import copy
3import warnings
4from typing import Any, Callable, Optional, Union
5
6import torch
7import torch.nn.functional as F
8from torch import Tensor
9from torch.nn.init import xavier_uniform_
10
11from .activation import MultiheadAttention
12from .container import ModuleList
13from .dropout import Dropout
14from .linear import Linear
15from .module import Module
16from .normalization import LayerNorm
17
18
19__all__ = [
20    "Transformer",
21    "TransformerEncoder",
22    "TransformerDecoder",
23    "TransformerEncoderLayer",
24    "TransformerDecoderLayer",
25]
26
27
28def _generate_square_subsequent_mask(
29    sz: int,
30    device: Optional[torch.device] = None,
31    dtype: Optional[torch.dtype] = None,
32) -> Tensor:
33    r"""Generate a square causal mask for the sequence.
34
35    The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
36    """
37    if device is None:
38        device = torch.device("cpu")
39    if dtype is None:
40        dtype = torch.float32
41    return torch.triu(
42        torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
43        diagonal=1,
44    )
45
46
47def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
48    if src.is_nested:
49        return None
50    else:
51        src_size = src.size()
52        if len(src_size) == 2:
53            # unbatched: S, E
54            return src_size[0]
55        else:
56            # batched: B, S, E if batch_first else S, B, E
57            seq_len_pos = 1 if batch_first else 0
58            return src_size[seq_len_pos]
59
60
61class Transformer(Module):
62    r"""A transformer model.
63
64    Users can modify the attributes as needed. The architecture
65    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
66    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
67    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
68    Processing Systems, pages 6000-6010.
69
70    Args:
71        d_model: the number of expected features in the encoder/decoder inputs (default=512).
72        nhead: the number of heads in the multiheadattention models (default=8).
73        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
74        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
75        dim_feedforward: the dimension of the feedforward network model (default=2048).
76        dropout: the dropout value (default=0.1).
77        activation: the activation function of encoder/decoder intermediate layer, can be a string
78            ("relu" or "gelu") or a unary callable. Default: relu
79        custom_encoder: custom encoder (default=None).
80        custom_decoder: custom decoder (default=None).
81        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
82        batch_first: If ``True``, then the input and output tensors are provided
83            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
84        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
85            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
86        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
87            bias. Default: ``True``.
88
89    Examples::
90        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
91        >>> src = torch.rand((10, 32, 512))
92        >>> tgt = torch.rand((20, 32, 512))
93        >>> out = transformer_model(src, tgt)
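
    The same model can also be used with batch-first inputs; this is an illustrative
    sketch (``batch_first=True`` moves the batch dimension to the front)::

        >>> # xdoctest: +SKIP
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, batch_first=True)
        >>> src = torch.rand((32, 10, 512))
        >>> tgt = torch.rand((32, 20, 512))
        >>> out = transformer_model(src, tgt)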
94
95    Note: A full example of how to apply the nn.Transformer module to a word language model is available at
96    https://github.com/pytorch/examples/tree/master/word_language_model
97    """
98
99    def __init__(
100        self,
101        d_model: int = 512,
102        nhead: int = 8,
103        num_encoder_layers: int = 6,
104        num_decoder_layers: int = 6,
105        dim_feedforward: int = 2048,
106        dropout: float = 0.1,
107        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
108        custom_encoder: Optional[Any] = None,
109        custom_decoder: Optional[Any] = None,
110        layer_norm_eps: float = 1e-5,
111        batch_first: bool = False,
112        norm_first: bool = False,
113        bias: bool = True,
114        device=None,
115        dtype=None,
116    ) -> None:
117        factory_kwargs = {"device": device, "dtype": dtype}
118        super().__init__()
119        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
120
121        if custom_encoder is not None:
122            self.encoder = custom_encoder
123        else:
124            encoder_layer = TransformerEncoderLayer(
125                d_model,
126                nhead,
127                dim_feedforward,
128                dropout,
129                activation,
130                layer_norm_eps,
131                batch_first,
132                norm_first,
133                bias,
134                **factory_kwargs,
135            )
136            encoder_norm = LayerNorm(
137                d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
138            )
139            self.encoder = TransformerEncoder(
140                encoder_layer, num_encoder_layers, encoder_norm
141            )
142
143        if custom_decoder is not None:
144            self.decoder = custom_decoder
145        else:
146            decoder_layer = TransformerDecoderLayer(
147                d_model,
148                nhead,
149                dim_feedforward,
150                dropout,
151                activation,
152                layer_norm_eps,
153                batch_first,
154                norm_first,
155                bias,
156                **factory_kwargs,
157            )
158            decoder_norm = LayerNorm(
159                d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
160            )
161            self.decoder = TransformerDecoder(
162                decoder_layer, num_decoder_layers, decoder_norm
163            )
164
165        self._reset_parameters()
166
167        self.d_model = d_model
168        self.nhead = nhead
169
170        self.batch_first = batch_first
171
172    def forward(
173        self,
174        src: Tensor,
175        tgt: Tensor,
176        src_mask: Optional[Tensor] = None,
177        tgt_mask: Optional[Tensor] = None,
178        memory_mask: Optional[Tensor] = None,
179        src_key_padding_mask: Optional[Tensor] = None,
180        tgt_key_padding_mask: Optional[Tensor] = None,
181        memory_key_padding_mask: Optional[Tensor] = None,
182        src_is_causal: Optional[bool] = None,
183        tgt_is_causal: Optional[bool] = None,
184        memory_is_causal: bool = False,
185    ) -> Tensor:
186        r"""Take in and process masked source/target sequences.
187
188        .. note::
189
190            If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
191            not allowed to participate in the attention,
192            which is the opposite of the definition for :attr:`attn_mask`
193            in :func:`torch.nn.functional.scaled_dot_product_attention`.
194
195        Args:
196            src: the sequence to the encoder (required).
197            tgt: the sequence to the decoder (required).
198            src_mask: the additive mask for the src sequence (optional).
199            tgt_mask: the additive mask for the tgt sequence (optional).
200            memory_mask: the additive mask for the encoder output (optional).
201            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
202            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
203            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
204            src_is_causal: If specified, applies a causal mask as ``src_mask``.
205                Default: ``None``; try to detect a causal mask.
206                Warning:
207                ``src_is_causal`` provides a hint that ``src_mask`` is
208                the causal mask. Providing incorrect hints can result in
209                incorrect execution, including forward and backward
210                compatibility issues.
211            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
212                Default: ``None``; try to detect a causal mask.
213                Warning:
214                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
215                the causal mask. Providing incorrect hints can result in
216                incorrect execution, including forward and backward
217                compatibility issues.
218            memory_is_causal: If specified, applies a causal mask as
219                ``memory_mask``.
220                Default: ``False``.
221                Warning:
222                ``memory_is_causal`` provides a hint that
223                ``memory_mask`` is the causal mask. Providing incorrect
224                hints can result in incorrect execution, including
225                forward and backward compatibility issues.
226
227        Shape:
228            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
229              `(N, S, E)` if `batch_first=True`.
230            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
231              `(N, T, E)` if `batch_first=True`.
232            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
233            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
234            - memory_mask: :math:`(T, S)`.
235            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
236            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
237            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
238
239            Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend only the unmasked
240            positions. If a BoolTensor is provided, positions with ``True``
241            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
242            is provided, it will be added to the attention weight.
243            [src/tgt/memory]_key_padding_mask causes specified elements in the key to be ignored by
244            the attention. If a BoolTensor is provided, the positions with the
245            value of ``True`` will be ignored while positions with the value of ``False`` will be unchanged.
246
247            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
248              `(N, T, E)` if `batch_first=True`.
249
250            Note: Due to the multi-head attention architecture in the transformer model,
251            the output sequence length of a transformer is the same as the input sequence
252            (i.e. target) length of the decoder.
253
254            where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
255            batch size, and :math:`E` is the feature number.
256
257        Examples:
258            >>> # xdoctest: +SKIP
259            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
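
            For illustration (reusing ``transformer_model``, ``src`` and ``tgt`` from the class
            example, with the default ``batch_first=False`` layout), a causal mask for ``tgt`` and
            an all-``False`` (no padding) boolean key padding mask for ``src`` could be passed as::

                >>> # xdoctest: +SKIP
                >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
                >>> src_key_padding_mask = torch.zeros(src.size(1), src.size(0), dtype=torch.bool)
                >>> output = transformer_model(src, tgt, tgt_mask=tgt_mask,
                ...                            src_key_padding_mask=src_key_padding_mask)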
260        """
261        is_batched = src.dim() == 3
262        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
263            raise RuntimeError("the batch number of src and tgt must be equal")
264        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
265            raise RuntimeError("the batch number of src and tgt must be equal")
266
267        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
268            raise RuntimeError(
269                "the feature number of src and tgt must be equal to d_model"
270            )
271
272        memory = self.encoder(
273            src,
274            mask=src_mask,
275            src_key_padding_mask=src_key_padding_mask,
276            is_causal=src_is_causal,
277        )
278        output = self.decoder(
279            tgt,
280            memory,
281            tgt_mask=tgt_mask,
282            memory_mask=memory_mask,
283            tgt_key_padding_mask=tgt_key_padding_mask,
284            memory_key_padding_mask=memory_key_padding_mask,
285            tgt_is_causal=tgt_is_causal,
286            memory_is_causal=memory_is_causal,
287        )
288        return output
289
290    @staticmethod
291    def generate_square_subsequent_mask(
292        sz: int,
293        device: Optional[torch.device] = None,
294        dtype: Optional[torch.dtype] = None,
295    ) -> Tensor:
296        r"""Generate a square causal mask for the sequence.
297
298        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
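
        For illustration, the mask for ``sz=3`` looks like the following
        (``-inf`` above the diagonal, ``0.0`` elsewhere)::

            >>> # xdoctest: +SKIP
            >>> nn.Transformer.generate_square_subsequent_mask(3)
            tensor([[0., -inf, -inf],
                    [0., 0., -inf],
                    [0., 0., 0.]])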
299        """
300        return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
301
302    def _reset_parameters(self):
303        r"""Initialize parameters in the transformer model."""
304        for p in self.parameters():
305            if p.dim() > 1:
306                xavier_uniform_(p)
307
308
309class TransformerEncoder(Module):
310    r"""TransformerEncoder is a stack of N encoder layers.
311
312    Users can build the BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.
313
314    Args:
315        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
316        num_layers: the number of sub-encoder-layers in the encoder (required).
317        norm: the layer normalization component (optional).
318        enable_nested_tensor: if True, the input will automatically be converted to a nested tensor
319            (and converted back on output). This will improve the overall performance of
320            TransformerEncoder when the padding rate is high. Default: ``True`` (enabled).
321
322    Examples::
323        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
324        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
325        >>> src = torch.rand(10, 32, 512)
326        >>> out = transformer_encoder(src)
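
    A padding mask can be supplied as well; this is an illustrative sketch in which
    ``True`` would mark padded (ignored) positions and no position is actually padded::

        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(32, 10, 512)
        >>> src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
        >>> out = transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)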
327    """
328
329    __constants__ = ["norm"]
330
331    def __init__(
332        self,
333        encoder_layer: "TransformerEncoderLayer",
334        num_layers: int,
335        norm: Optional[Module] = None,
336        enable_nested_tensor: bool = True,
337        mask_check: bool = True,
338    ) -> None:
339        super().__init__()
340        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
341        self.layers = _get_clones(encoder_layer, num_layers)
342        self.num_layers = num_layers
343        self.norm = norm
344        # this attribute saves the value provided at object construction
345        self.enable_nested_tensor = enable_nested_tensor
346        # this attribute controls whether nested tensors are used
347        self.use_nested_tensor = enable_nested_tensor
348        self.mask_check = mask_check
349
350        enc_layer = "encoder_layer"
351        why_not_sparsity_fast_path = ""
352        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
353            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
354        elif encoder_layer.norm_first:
355            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
356        elif not encoder_layer.self_attn.batch_first:
357            why_not_sparsity_fast_path = (
358                f"{enc_layer}.self_attn.batch_first was not True"
359                + " (use batch_first for better inference performance)"
360            )
361        elif not encoder_layer.self_attn._qkv_same_embed_dim:
362            why_not_sparsity_fast_path = (
363                f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
364            )
365        elif encoder_layer.self_attn.in_proj_bias is None:
366            why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
367        elif not encoder_layer.activation_relu_or_gelu:
368            why_not_sparsity_fast_path = (
369                f"{enc_layer}.activation_relu_or_gelu was not True"
370            )
371        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
372            why_not_sparsity_fast_path = (
373                f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
374            )
375        elif encoder_layer.self_attn.num_heads % 2 == 1:
376            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
377
378        if enable_nested_tensor and why_not_sparsity_fast_path:
379            warnings.warn(
380                f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}"
381            )
382            self.use_nested_tensor = False
383
384    def forward(
385        self,
386        src: Tensor,
387        mask: Optional[Tensor] = None,
388        src_key_padding_mask: Optional[Tensor] = None,
389        is_causal: Optional[bool] = None,
390    ) -> Tensor:
391        r"""Pass the input through the encoder layers in turn.
392
393        Args:
394            src: the sequence to the encoder (required).
395            mask: the mask for the src sequence (optional).
396            src_key_padding_mask: the mask for the src keys per batch (optional).
397            is_causal: If specified, applies a causal mask as ``mask``.
398                Default: ``None``; try to detect a causal mask.
399                Warning:
400                ``is_causal`` provides a hint that ``mask`` is the
401                causal mask. Providing incorrect hints can result in
402                incorrect execution, including forward and backward
403                compatibility issues.
404
405        Shape:
406            see the docs in :class:`~torch.nn.Transformer`.
407        """
408        src_key_padding_mask = F._canonical_mask(
409            mask=src_key_padding_mask,
410            mask_name="src_key_padding_mask",
411            other_type=F._none_or_dtype(mask),
412            other_name="mask",
413            target_type=src.dtype,
414        )
415
416        mask = F._canonical_mask(
417            mask=mask,
418            mask_name="mask",
419            other_type=None,
420            other_name="",
421            target_type=src.dtype,
422            check_other=False,
423        )
424
425        output = src
426        convert_to_nested = False
427        first_layer = self.layers[0]
428        src_key_padding_mask_for_layers = src_key_padding_mask
429        why_not_sparsity_fast_path = ""
430        str_first_layer = "self.layers[0]"
431        batch_first = first_layer.self_attn.batch_first
432        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
433
434        if not is_fastpath_enabled:
435            why_not_sparsity_fast_path = (
436                "torch.backends.mha.get_fastpath_enabled() was not True"
437            )
438        elif not hasattr(self, "use_nested_tensor"):
439            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
440        elif not self.use_nested_tensor:
441            why_not_sparsity_fast_path = (
442                "self.use_nested_tensor (set in init) was not True"
443            )
444        elif first_layer.training:
445            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
446        elif not src.dim() == 3:
447            why_not_sparsity_fast_path = (
448                f"input not batched; expected src.dim() of 3 but got {src.dim()}"
449            )
450        elif src_key_padding_mask is None:
451            why_not_sparsity_fast_path = "src_key_padding_mask was None"
452        elif (
453            (not hasattr(self, "mask_check")) or self.mask_check
454        ) and not torch._nested_tensor_from_mask_left_aligned(
455            src, src_key_padding_mask.logical_not()
456        ):
457            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
458        elif output.is_nested:
459            why_not_sparsity_fast_path = "NestedTensor input is not supported"
460        elif mask is not None:
461            why_not_sparsity_fast_path = (
462                "src_key_padding_mask and mask were both supplied"
463            )
464        elif torch.is_autocast_enabled():
465            why_not_sparsity_fast_path = "autocast is enabled"
466
467        if not why_not_sparsity_fast_path:
468            tensor_args = (
469                src,
470                first_layer.self_attn.in_proj_weight,
471                first_layer.self_attn.in_proj_bias,
472                first_layer.self_attn.out_proj.weight,
473                first_layer.self_attn.out_proj.bias,
474                first_layer.norm1.weight,
475                first_layer.norm1.bias,
476                first_layer.norm2.weight,
477                first_layer.norm2.bias,
478                first_layer.linear1.weight,
479                first_layer.linear1.bias,
480                first_layer.linear2.weight,
481                first_layer.linear2.bias,
482            )
483            _supported_device_type = [
484                "cpu",
485                "cuda",
486                torch.utils.backend_registration._privateuse1_backend_name,
487            ]
488            if torch.overrides.has_torch_function(tensor_args):
489                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
490            elif src.device.type not in _supported_device_type:
491                why_not_sparsity_fast_path = (
492                    f"src device is not one of {_supported_device_type}"
493                )
494            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
495                why_not_sparsity_fast_path = (
496                    "grad is enabled and at least one of query or the "
497                    "input/output projection weights or biases requires_grad"
498                )
499
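            # If the fast path is still viable and a padding mask was given, pack the padded
            # batch into a nested tensor so per-sequence lengths replace explicit padding;
            # the padding mask is then dropped for the per-layer calls below.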
500            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
501                convert_to_nested = True
502                output = torch._nested_tensor_from_mask(
503                    output, src_key_padding_mask.logical_not(), mask_check=False
504                )
505                src_key_padding_mask_for_layers = None
506
507        seq_len = _get_seq_len(src, batch_first)
508        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
509
510        for mod in self.layers:
511            output = mod(
512                output,
513                src_mask=mask,
514                is_causal=is_causal,
515                src_key_padding_mask=src_key_padding_mask_for_layers,
516            )
517
518        if convert_to_nested:
519            output = output.to_padded_tensor(0.0, src.size())
520
521        if self.norm is not None:
522            output = self.norm(output)
523
524        return output
525
526
527class TransformerDecoder(Module):
528    r"""TransformerDecoder is a stack of N decoder layers.
529
530    Args:
531        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
532        num_layers: the number of sub-decoder-layers in the decoder (required).
533        norm: the layer normalization component (optional).
534
535    Examples::
536        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
537        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
538        >>> memory = torch.rand(10, 32, 512)
539        >>> tgt = torch.rand(20, 32, 512)
540        >>> out = transformer_decoder(tgt, memory)
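
    For illustration, a causal mask can also be applied to ``tgt`` (the ``tgt_is_causal``
    hint is valid here because the supplied mask really is causal)::

        >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(20)
        >>> out = transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)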
541    """
542
543    __constants__ = ["norm"]
544
545    def __init__(
546        self,
547        decoder_layer: "TransformerDecoderLayer",
548        num_layers: int,
549        norm: Optional[Module] = None,
550    ) -> None:
551        super().__init__()
552        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
553        self.layers = _get_clones(decoder_layer, num_layers)
554        self.num_layers = num_layers
555        self.norm = norm
556
557    def forward(
558        self,
559        tgt: Tensor,
560        memory: Tensor,
561        tgt_mask: Optional[Tensor] = None,
562        memory_mask: Optional[Tensor] = None,
563        tgt_key_padding_mask: Optional[Tensor] = None,
564        memory_key_padding_mask: Optional[Tensor] = None,
565        tgt_is_causal: Optional[bool] = None,
566        memory_is_causal: bool = False,
567    ) -> Tensor:
568        r"""Pass the inputs (and mask) through the decoder layers in turn.
569
570        Args:
571            tgt: the sequence to the decoder (required).
572            memory: the sequence from the last layer of the encoder (required).
573            tgt_mask: the mask for the tgt sequence (optional).
574            memory_mask: the mask for the memory sequence (optional).
575            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
576            memory_key_padding_mask: the mask for the memory keys per batch (optional).
577            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
578                Default: ``None``; try to detect a causal mask.
579                Warning:
580                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
581                the causal mask. Providing incorrect hints can result in
582                incorrect execution, including forward and backward
583                compatibility issues.
584            memory_is_causal: If specified, applies a causal mask as
585                ``memory_mask``.
586                Default: ``False``.
587                Warning:
588                ``memory_is_causal`` provides a hint that
589                ``memory_mask`` is the causal mask. Providing incorrect
590                hints can result in incorrect execution, including
591                forward and backward compatibility issues.
592
593        Shape:
594            see the docs in :class:`~torch.nn.Transformer`.
595        """
596        output = tgt
597
598        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
599        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
600
601        for mod in self.layers:
602            output = mod(
603                output,
604                memory,
605                tgt_mask=tgt_mask,
606                memory_mask=memory_mask,
607                tgt_key_padding_mask=tgt_key_padding_mask,
608                memory_key_padding_mask=memory_key_padding_mask,
609                tgt_is_causal=tgt_is_causal,
610                memory_is_causal=memory_is_causal,
611            )
612
613        if self.norm is not None:
614            output = self.norm(output)
615
616        return output
617
618
619class TransformerEncoderLayer(Module):
620    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
621
622    This standard encoder layer is based on the paper "Attention Is All You Need".
623    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
624    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
625    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
626    it in a different way during application.
627
628    TransformerEncoderLayer can handle either traditional torch.Tensor inputs,
629    or Nested Tensor inputs.  Derived classes are expected to similarly accept
630    both input formats.  (Not all combinations of inputs are currently
631    supported by TransformerEncoderLayer while Nested Tensor is in prototype
632    state.)
633
634    If you are implementing a custom layer, you may derive it either from
635    the Module or TransformerEncoderLayer class.  If your custom layer
636    supports both torch.Tensor and Nested Tensor inputs, make its
637    implementation a derived class of TransformerEncoderLayer. If your custom
638    layer supports only torch.Tensor inputs, derive its implementation from
639    Module.
640
641    Args:
642        d_model: the number of expected features in the input (required).
643        nhead: the number of heads in the multiheadattention models (required).
644        dim_feedforward: the dimension of the feedforward network model (default=2048).
645        dropout: the dropout value (default=0.1).
646        activation: the activation function of the intermediate layer, can be a string
647            ("relu" or "gelu") or a unary callable. Default: relu
648        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
649        batch_first: If ``True``, then the input and output tensors are provided
650            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
651        norm_first: if ``True``, layer norm is done prior to attention and feedforward
652            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
653        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
654            bias. Default: ``True``.
655
656    Examples::
657        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
658        >>> src = torch.rand(10, 32, 512)
659        >>> out = encoder_layer(src)
660
661    Alternatively, when ``batch_first`` is ``True``:
662        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
663        >>> src = torch.rand(32, 10, 512)
664        >>> out = encoder_layer(src)
665
666    Fast path:
667        forward() will use a special optimized implementation described in
668        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
669        conditions are met:
670
671        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
672          argument ``requires_grad``
673        - training is disabled (using ``.eval()``)
674        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
675        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
676        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
677        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
678          nor ``src_key_padding_mask`` is passed
679        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
680          unless the caller has manually modified one without modifying the other)
681
682        If the optimized implementation is in use, a
683        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
684        passed for ``src`` to represent padding more efficiently than using a padding
685        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
686        returned, and an additional speedup proportional to the fraction of the input that
687        is padding can be expected.
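
        For illustration, assuming the ``batch_first=True`` ``encoder_layer`` from above has been
        switched to inference mode (``encoder_layer.eval()``), a nested-tensor input holding two
        sequences of different lengths could be passed as::

            >>> # xdoctest: +SKIP
            >>> with torch.inference_mode():
            ...     nt = torch.nested.nested_tensor([torch.rand(6, 512), torch.rand(10, 512)])
            ...     out = encoder_layer(nt)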
688
689        .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
690         https://arxiv.org/abs/2205.14135
691
692    """
693
694    __constants__ = ["norm_first"]
695
696    def __init__(
697        self,
698        d_model: int,
699        nhead: int,
700        dim_feedforward: int = 2048,
701        dropout: float = 0.1,
702        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
703        layer_norm_eps: float = 1e-5,
704        batch_first: bool = False,
705        norm_first: bool = False,
706        bias: bool = True,
707        device=None,
708        dtype=None,
709    ) -> None:
710        factory_kwargs = {"device": device, "dtype": dtype}
711        super().__init__()
712        self.self_attn = MultiheadAttention(
713            d_model,
714            nhead,
715            dropout=dropout,
716            bias=bias,
717            batch_first=batch_first,
718            **factory_kwargs,
719        )
720        # Implementation of Feedforward model
721        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
722        self.dropout = Dropout(dropout)
723        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
724
725        self.norm_first = norm_first
726        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
727        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
728        self.dropout1 = Dropout(dropout)
729        self.dropout2 = Dropout(dropout)
730
731        # Legacy string support for activation function.
732        if isinstance(activation, str):
733            activation = _get_activation_fn(activation)
734
735        # We can't test self.activation in forward() in TorchScript,
736        # so stash some information about it instead.
737        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
738            self.activation_relu_or_gelu = 1
739        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
740            self.activation_relu_or_gelu = 2
741        else:
742            self.activation_relu_or_gelu = 0
743        self.activation = activation
744
745    def __setstate__(self, state):
746        super().__setstate__(state)
747        if not hasattr(self, "activation"):
748            self.activation = F.relu
749
750    def forward(
751        self,
752        src: Tensor,
753        src_mask: Optional[Tensor] = None,
754        src_key_padding_mask: Optional[Tensor] = None,
755        is_causal: bool = False,
756    ) -> Tensor:
757        r"""Pass the input through the encoder layer.
758
759        Args:
760            src: the sequence to the encoder layer (required).
761            src_mask: the mask for the src sequence (optional).
762            src_key_padding_mask: the mask for the src keys per batch (optional).
763            is_causal: If specified, applies a causal mask as ``src_mask``.
764                Default: ``False``.
765                Warning:
766                ``is_causal`` provides a hint that ``src_mask`` is the
767                causal mask. Providing incorrect hints can result in
768                incorrect execution, including forward and backward
769                compatibility issues.
770
771        Shape:
772            see the docs in :class:`~torch.nn.Transformer`.
773        """
774        src_key_padding_mask = F._canonical_mask(
775            mask=src_key_padding_mask,
776            mask_name="src_key_padding_mask",
777            other_type=F._none_or_dtype(src_mask),
778            other_name="src_mask",
779            target_type=src.dtype,
780        )
781
782        src_mask = F._canonical_mask(
783            mask=src_mask,
784            mask_name="src_mask",
785            other_type=None,
786            other_name="",
787            target_type=src.dtype,
788            check_other=False,
789        )
790
791        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
792
793        why_not_sparsity_fast_path = ""
794        if not is_fastpath_enabled:
795            why_not_sparsity_fast_path = (
796                "torch.backends.mha.get_fastpath_enabled() was not True"
797            )
798        elif not src.dim() == 3:
799            why_not_sparsity_fast_path = (
800                f"input not batched; expected src.dim() of 3 but got {src.dim()}"
801            )
802        elif self.training:
803            why_not_sparsity_fast_path = "training is enabled"
804        elif not self.self_attn.batch_first:
805            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
806        elif self.self_attn.in_proj_bias is None:
807            why_not_sparsity_fast_path = "self_attn was passed bias=False"
808        elif not self.self_attn._qkv_same_embed_dim:
809            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
810        elif not self.activation_relu_or_gelu:
811            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
812        elif not (self.norm1.eps == self.norm2.eps):
813            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
814        elif src.is_nested and (
815            src_key_padding_mask is not None or src_mask is not None
816        ):
817            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
818        elif self.self_attn.num_heads % 2 == 1:
819            why_not_sparsity_fast_path = "num_heads is odd"
820        elif torch.is_autocast_enabled():
821            why_not_sparsity_fast_path = "autocast is enabled"
822        elif any(
823            len(getattr(m, "_forward_hooks", {}))
824            + len(getattr(m, "_forward_pre_hooks", {}))
825            for m in self.modules()
826        ):
827            why_not_sparsity_fast_path = "forward hooks or pre-hooks are attached to the module"
828        if not why_not_sparsity_fast_path:
829            tensor_args = (
830                src,
831                self.self_attn.in_proj_weight,
832                self.self_attn.in_proj_bias,
833                self.self_attn.out_proj.weight,
834                self.self_attn.out_proj.bias,
835                self.norm1.weight,
836                self.norm1.bias,
837                self.norm2.weight,
838                self.norm2.bias,
839                self.linear1.weight,
840                self.linear1.bias,
841                self.linear2.weight,
842                self.linear2.bias,
843            )
844
845            # We have to use list comprehensions below because TorchScript does not support
846            # generator expressions.
847            _supported_device_type = [
848                "cpu",
849                "cuda",
850                torch.utils.backend_registration._privateuse1_backend_name,
851            ]
852            if torch.overrides.has_torch_function(tensor_args):
853                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
854            elif not all(
855                (x.device.type in _supported_device_type) for x in tensor_args
856            ):
857                why_not_sparsity_fast_path = (
858                    "some Tensor argument's device is not one of "
859                    f"{_supported_device_type}"
860                )
861            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
862                why_not_sparsity_fast_path = (
863                    "grad is enabled and at least one of query or the "
864                    "input/output projection weights or biases requires_grad"
865                )
866
867            if not why_not_sparsity_fast_path:
868                merged_mask, mask_type = self.self_attn.merge_masks(
869                    src_mask, src_key_padding_mask, src
870                )
871                return torch._transformer_encoder_layer_fwd(
872                    src,
873                    self.self_attn.embed_dim,
874                    self.self_attn.num_heads,
875                    self.self_attn.in_proj_weight,
876                    self.self_attn.in_proj_bias,
877                    self.self_attn.out_proj.weight,
878                    self.self_attn.out_proj.bias,
879                    self.activation_relu_or_gelu == 2,
880                    self.norm_first,
881                    self.norm1.eps,
882                    self.norm1.weight,
883                    self.norm1.bias,
884                    self.norm2.weight,
885                    self.norm2.bias,
886                    self.linear1.weight,
887                    self.linear1.bias,
888                    self.linear2.weight,
889                    self.linear2.bias,
890                    merged_mask,
891                    mask_type,
892                )
893
894        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
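        # When norm_first is True (pre-norm), LayerNorm is applied to each sublayer's input
        # and the residual is added to the un-normalized stream; when False (post-norm, the
        # original formulation), LayerNorm is applied to the residual sum instead.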
895        x = src
896        if self.norm_first:
897            x = x + self._sa_block(
898                self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal
899            )
900            x = x + self._ff_block(self.norm2(x))
901        else:
902            x = self.norm1(
903                x
904                + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)
905            )
906            x = self.norm2(x + self._ff_block(x))
907
908        return x
909
910    # self-attention block
911    def _sa_block(
912        self,
913        x: Tensor,
914        attn_mask: Optional[Tensor],
915        key_padding_mask: Optional[Tensor],
916        is_causal: bool = False,
917    ) -> Tensor:
918        x = self.self_attn(
919            x,
920            x,
921            x,
922            attn_mask=attn_mask,
923            key_padding_mask=key_padding_mask,
924            need_weights=False,
925            is_causal=is_causal,
926        )[0]
927        return self.dropout1(x)
928
929    # feed forward block
930    def _ff_block(self, x: Tensor) -> Tensor:
931        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
932        return self.dropout2(x)
933
934
935class TransformerDecoderLayer(Module):
936    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
937
938    This standard decoder layer is based on the paper "Attention Is All You Need".
939    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
940    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
941    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
942    it in a different way during application.
943
944    Args:
945        d_model: the number of expected features in the input (required).
946        nhead: the number of heads in the multiheadattention models (required).
947        dim_feedforward: the dimension of the feedforward network model (default=2048).
948        dropout: the dropout value (default=0.1).
949        activation: the activation function of the intermediate layer, can be a string
950            ("relu" or "gelu") or a unary callable. Default: relu
951        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
952        batch_first: If ``True``, then the input and output tensors are provided
953            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
954        norm_first: if ``True``, layer norm is done prior to self attention, multihead
955            attention and feedforward operations, respectively. Otherwise it's done after.
956            Default: ``False`` (after).
957        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
958            bias. Default: ``True``.
959
960    Examples::
961        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
962        >>> memory = torch.rand(10, 32, 512)
963        >>> tgt = torch.rand(20, 32, 512)
964        >>> out = decoder_layer(tgt, memory)
965
966    Alternatively, when ``batch_first`` is ``True``:
967        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
968        >>> memory = torch.rand(32, 10, 512)
969        >>> tgt = torch.rand(32, 20, 512)
970        >>> out = decoder_layer(tgt, memory)
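
    Or, as an illustrative sketch with pre-layer normalization (``norm_first=True``)::

        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)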
971    """
972
973    __constants__ = ["norm_first"]
974
975    def __init__(
976        self,
977        d_model: int,
978        nhead: int,
979        dim_feedforward: int = 2048,
980        dropout: float = 0.1,
981        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
982        layer_norm_eps: float = 1e-5,
983        batch_first: bool = False,
984        norm_first: bool = False,
985        bias: bool = True,
986        device=None,
987        dtype=None,
988    ) -> None:
989        factory_kwargs = {"device": device, "dtype": dtype}
990        super().__init__()
991        self.self_attn = MultiheadAttention(
992            d_model,
993            nhead,
994            dropout=dropout,
995            batch_first=batch_first,
996            bias=bias,
997            **factory_kwargs,
998        )
999        self.multihead_attn = MultiheadAttention(
1000            d_model,
1001            nhead,
1002            dropout=dropout,
1003            batch_first=batch_first,
1004            bias=bias,
1005            **factory_kwargs,
1006        )
1007        # Implementation of Feedforward model
1008        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
1009        self.dropout = Dropout(dropout)
1010        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
1011
1012        self.norm_first = norm_first
1013        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
1014        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
1015        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
1016        self.dropout1 = Dropout(dropout)
1017        self.dropout2 = Dropout(dropout)
1018        self.dropout3 = Dropout(dropout)
1019
1020        # Legacy string support for activation function.
1021        if isinstance(activation, str):
1022            self.activation = _get_activation_fn(activation)
1023        else:
1024            self.activation = activation
1025
1026    def __setstate__(self, state):
1027        if "activation" not in state:
1028            state["activation"] = F.relu
1029        super().__setstate__(state)
1030
1031    def forward(
1032        self,
1033        tgt: Tensor,
1034        memory: Tensor,
1035        tgt_mask: Optional[Tensor] = None,
1036        memory_mask: Optional[Tensor] = None,
1037        tgt_key_padding_mask: Optional[Tensor] = None,
1038        memory_key_padding_mask: Optional[Tensor] = None,
1039        tgt_is_causal: bool = False,
1040        memory_is_causal: bool = False,
1041    ) -> Tensor:
1042        r"""Pass the inputs (and mask) through the decoder layer.
1043
1044        Args:
1045            tgt: the sequence to the decoder layer (required).
1046            memory: the sequence from the last layer of the encoder (required).
1047            tgt_mask: the mask for the tgt sequence (optional).
1048            memory_mask: the mask for the memory sequence (optional).
1049            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
1050            memory_key_padding_mask: the mask for the memory keys per batch (optional).
1051            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
1052                Default: ``False``.
1053                Warning:
1054                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
1055                the causal mask. Providing incorrect hints can result in
1056                incorrect execution, including forward and backward
1057                compatibility issues.
1058            memory_is_causal: If specified, applies a causal mask as
1059                ``memory_mask``.
1060                Default: ``False``.
1061                Warning:
1062                ``memory_is_causal`` provides a hint that
1063                ``memory_mask`` is the causal mask. Providing incorrect
1064                hints can result in incorrect execution, including
1065                forward and backward compatibility issues.
1066
1067        Shape:
1068            see the docs in :class:`~torch.nn.Transformer`.
1069        """
1070        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
1071
1072        x = tgt
1073        if self.norm_first:
1074            x = x + self._sa_block(
1075                self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
1076            )
1077            x = x + self._mha_block(
1078                self.norm2(x),
1079                memory,
1080                memory_mask,
1081                memory_key_padding_mask,
1082                memory_is_causal,
1083            )
1084            x = x + self._ff_block(self.norm3(x))
1085        else:
1086            x = self.norm1(
1087                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
1088            )
1089            x = self.norm2(
1090                x
1091                + self._mha_block(
1092                    x, memory, memory_mask, memory_key_padding_mask, memory_is_causal
1093                )
1094            )
1095            x = self.norm3(x + self._ff_block(x))
1096
1097        return x
1098
1099    # self-attention block
1100    def _sa_block(
1101        self,
1102        x: Tensor,
1103        attn_mask: Optional[Tensor],
1104        key_padding_mask: Optional[Tensor],
1105        is_causal: bool = False,
1106    ) -> Tensor:
1107        x = self.self_attn(
1108            x,
1109            x,
1110            x,
1111            attn_mask=attn_mask,
1112            key_padding_mask=key_padding_mask,
1113            is_causal=is_causal,
1114            need_weights=False,
1115        )[0]
1116        return self.dropout1(x)
1117
1118    # multihead attention block
1119    def _mha_block(
1120        self,
1121        x: Tensor,
1122        mem: Tensor,
1123        attn_mask: Optional[Tensor],
1124        key_padding_mask: Optional[Tensor],
1125        is_causal: bool = False,
1126    ) -> Tensor:
1127        x = self.multihead_attn(
1128            x,
1129            mem,
1130            mem,
1131            attn_mask=attn_mask,
1132            key_padding_mask=key_padding_mask,
1133            is_causal=is_causal,
1134            need_weights=False,
1135        )[0]
1136        return self.dropout2(x)
1137
1138    # feed forward block
1139    def _ff_block(self, x: Tensor) -> Tensor:
1140        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
1141        return self.dropout3(x)
1142
1143
1144def _get_clones(module, N):
1145    # FIXME: copy.deepcopy() is not defined on nn.Module
1146    return ModuleList([copy.deepcopy(module) for i in range(N)])
1147
1148
1149def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
1150    if activation == "relu":
1151        return F.relu
1152    elif activation == "gelu":
1153        return F.gelu
1154
1155    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
1156
1157
1158def _detect_is_causal_mask(
1159    mask: Optional[Tensor],
1160    is_causal: Optional[bool] = None,
1161    size: Optional[int] = None,
1162) -> bool:
1163    """Return whether the given attention mask is causal.
1164
1165    Warning:
1166    If ``is_causal`` is not ``None``, its value will be returned as is.  If a
1167    user supplies an incorrect ``is_causal`` hint:
1168
1169    ``is_causal=False`` when the mask is in fact a causal attention mask
1170       may lead to reduced performance relative to what would be achievable
1171       with ``is_causal=True``;
1172    ``is_causal=True`` when the mask is in fact not a causal attention mask
1173       may lead to incorrect and unpredictable execution - in some scenarios,
1174       a causal mask may be applied based on the hint, in other execution
1175       scenarios the specified mask may be used.  The choice may not appear
1176       to be deterministic, in that a number of factors like alignment,
1177       hardware SKU, etc. influence the decision whether to use a mask or
1178       rely on the hint.
1179    If ``size`` is not None, this checks whether the mask is a causal mask of the provided size;
1180       otherwise, it checks for any causal mask.
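
    For illustration (using the helpers defined in this module)::

        >>> m = _generate_square_subsequent_mask(4)
        >>> _detect_is_causal_mask(m, is_causal=None, size=4)
        True
        >>> _detect_is_causal_mask(m, is_causal=None, size=3)
        False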
1181    """
1182    # Prevent type refinement
1183    make_causal = is_causal is True
1184
1185    if is_causal is None and mask is not None:
1186        sz = size if size is not None else mask.size(-2)
1187        causal_comparison = _generate_square_subsequent_mask(
1188            sz, device=mask.device, dtype=mask.dtype
1189        )
1190
1191        # Do not use `torch.equal` so we handle batched masks by
1192        # broadcasting the comparison.
1193        if mask.size() == causal_comparison.size():
1194            make_causal = bool((mask == causal_comparison).all())
1195        else:
1196            make_causal = False
1197
1198    return make_causal
1199