# mypy: allow-untyped-defs
import copy
import warnings
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.init import xavier_uniform_

from .activation import MultiheadAttention
from .container import ModuleList
from .dropout import Dropout
from .linear import Linear
from .module import Module
from .normalization import LayerNorm


__all__ = [
    "Transformer",
    "TransformerEncoder",
    "TransformerDecoder",
    "TransformerEncoderLayer",
    "TransformerDecoderLayer",
]


def _generate_square_subsequent_mask(
    sz: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Generate a square causal mask for the sequence.

    The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        dtype = torch.float32
    return torch.triu(
        torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
        diagonal=1,
    )


def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
    if src.is_nested:
        return None
    else:
        src_size = src.size()
        if len(src_size) == 2:
            # unbatched: S, E
            return src_size[0]
        else:
            # batched: B, S, E if batch_first else S, B, E
            seq_len_pos = 1 if batch_first else 0
            return src_size[seq_len_pos]


class Transformer(Module):
    r"""A transformer model.

    Users can modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
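
    .. note::
        For autoregressive decoding, a causal mask for ``tgt`` can be built with
        :meth:`generate_square_subsequent_mask` and passed to :meth:`forward` as
        ``tgt_mask``. A minimal sketch (shapes assume the default ``batch_first=False``):

        >>> import torch
        >>> import torch.nn as nn
        >>> model = nn.Transformer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
        >>> out = model(src, tgt, tgt_mask=tgt_mask)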

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example of how to apply the nn.Transformer module to the word language modeling
    task is available at https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward,
                dropout,
                activation,
                layer_norm_eps,
                batch_first,
                norm_first,
                bias,
                **factory_kwargs,
            )
            encoder_norm = LayerNorm(
                d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
            )
            self.encoder = TransformerEncoder(
                encoder_layer, num_encoder_layers, encoder_norm
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward,
                dropout,
                activation,
                layer_norm_eps,
                batch_first,
                norm_first,
                bias,
                **factory_kwargs,
            )
            decoder_norm = LayerNorm(
                d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
            )
            self.decoder = TransformerDecoder(
                decoder_layer, num_decoder_layers, decoder_norm
            )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        src_is_causal: Optional[bool] = None,
        tgt_is_causal: Optional[bool] = None,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Take in and process masked source/target sequences.

        .. note::

            If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments,
            positions with a ``True`` value are not allowed to participate in the attention,
            which is the opposite of the definition for :attr:`attn_mask`
            in :func:`torch.nn.functional.scaled_dot_product_attention`.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
            src_is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``src_is_causal`` provides a hint that ``src_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                forward and backward compatibility.

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend to the unmasked
            positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
            batch size, and :math:`E` is the feature number.

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError(
                "the feature number of src and tgt must be equal to d_model"
            )

        memory = self.encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            is_causal=src_is_causal,
        )
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_is_causal=tgt_is_causal,
            memory_is_causal=memory_is_causal,
        )
        return output

    @staticmethod
    def generate_square_subsequent_mask(
        sz: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        """
        return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)

    def _reset_parameters(self):
        r"""Initialize parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Users can build the BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, the input will automatically be converted to a nested tensor
            (and converted back on output). This will improve the overall performance of
            TransformerEncoder when the padding rate is high. Default: ``True`` (enabled).
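
    .. note::
        A minimal sketch of supplying a padding mask (an assumed batch of 2 sequences
        where the second half of the second sequence is padding):

        >>> import torch
        >>> import torch.nn as nn
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        >>> src = torch.rand(2, 10, 512)
        >>> src_key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
        >>> src_key_padding_mask[1, 5:] = True  # True marks positions to ignore
        >>> out = transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)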

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(
        self,
        encoder_layer: "TransformerEncoderLayer",
        num_layers: int,
        norm: Optional[Module] = None,
        enable_nested_tensor: bool = True,
        mask_check: bool = True,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # this attribute saves the value provided at object construction
        self.enable_nested_tensor = enable_nested_tensor
        # this attribute controls whether nested tensors are used
        self.use_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

        enc_layer = "encoder_layer"
        why_not_sparsity_fast_path = ""
        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
        elif encoder_layer.norm_first:
            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
        elif not encoder_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = (
                f"{enc_layer}.self_attn.batch_first was not True"
                + " (use batch_first for better inference performance)"
            )
        elif not encoder_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = (
                f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
            )
        elif encoder_layer.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
        elif not encoder_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = (
                f"{enc_layer}.activation_relu_or_gelu was not True"
            )
        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
            why_not_sparsity_fast_path = (
                f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
            )
        elif encoder_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"

        if enable_nested_tensor and why_not_sparsity_fast_path:
            warnings.warn(
                f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}"
            )
            self.use_nested_tensor = False

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        is_causal: Optional[bool] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``is_causal`` provides a hint that ``mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
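
        Example:
            A minimal sketch of passing an explicit causal mask together with the
            ``is_causal`` hint (the mask below is assumed to match the sequence length
            of an already constructed ``src`` and ``transformer_encoder``):

            >>> # xdoctest: +SKIP
            >>> mask = nn.Transformer.generate_square_subsequent_mask(src.size(0))
            >>> out = transformer_encoder(src, mask=mask, is_causal=True)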
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype,
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ""
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = (
                "torch.backends.mha.get_fastpath_enabled() was not True"
            )
        elif not hasattr(self, "use_nested_tensor"):
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = (
                "self.use_nested_tensor (set in init) was not True"
            )
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = (
                f"input not batched; expected src.dim() of 3 but got {src.dim()}"
            )
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (
            (not hasattr(self, "mask_check")) or self.mask_check
        ) and not torch._nested_tensor_from_mask_left_aligned(
            src, src_key_padding_mask.logical_not()
        ):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = (
                "src_key_padding_mask and mask were both supplied"
            )
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = [
                "cpu",
                "cuda",
                torch.utils.backend_registration._privateuse1_backend_name,
            ]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = (
                    f"src device is not one of {_supported_device_type}"
                )
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )

        if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
            convert_to_nested = True
            output = torch._nested_tensor_from_mask(
                output, src_key_padding_mask.logical_not(), mask_check=False
            )
            src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
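
        # At this point the fast-path bookkeeping is done: ``output`` may have been
        # converted to a nested tensor (with the padding mask folded in), and
        # ``is_causal`` reflects either the caller's hint or the detected mask shape.
        # The remaining work is to run each encoder layer in sequence.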
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                is_causal=is_causal,
                src_key_padding_mask=src_key_padding_mask_for_layers,
            )

        if convert_to_nested:
            output = output.to_padded_tensor(0.0, src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    __constants__ = ["norm"]

    def __init__(
        self,
        decoder_layer: "TransformerDecoderLayer",
        num_layers: int,
        norm: Optional[Module] = None,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: Optional[bool] = None,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and masks) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                forward and backward compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
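
        Example:
            A minimal sketch of causal decoding with a generated ``tgt_mask``
            (``tgt``, ``memory``, and ``transformer_decoder`` are assumed to already exist):

            >>> # xdoctest: +SKIP
            >>> tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
            >>> out = transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)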
        """
        output = tgt

        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)

        for mod in self.layers:
            output = mod(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
                memory_is_causal=memory_is_causal,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way for their application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs. Derived classes are expected to similarly accept
    both input formats. (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class. If your custom layer
    supports both torch.Tensor and Nested Tensor inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
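
    .. note::
        ``norm_first=True`` applies LayerNorm before the attention and feedforward
        blocks (the "pre-LN" layout discussed in https://arxiv.org/abs/2002.04745,
        which this implementation references). A minimal sketch:

        >>> import torch
        >>> import torch.nn as nn
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, norm_first=True)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)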

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
        https://arxiv.org/abs/2205.14135

    """

    __constants__ = ["norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            bias=bias,
            batch_first=batch_first,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
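        # The stashed flag encodes the activation for the fused fast path:
        # 1 -> ReLU, 2 -> GELU, 0 -> anything else (which disables the fast path).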
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype,
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        why_not_sparsity_fast_path = ""
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = (
                "torch.backends.mha.get_fastpath_enabled() was not True"
            )
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = (
                f"input not batched; expected src.dim() of 3 but got {src.dim()}"
            )
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif self.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = "self_attn was passed bias=False"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (
            src_key_padding_mask is not None or src_mask is not None
        ):
            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        elif any(
            len(getattr(m, "_forward_hooks", {}))
            + len(getattr(m, "_forward_pre_hooks", {}))
            for m in self.modules()
        ):
            why_not_sparsity_fast_path = "forward pre-/hooks are attached to the module"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = [
                "cpu",
                "cuda",
                torch.utils.backend_registration._privateuse1_backend_name,
            ]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                (x.device.type in _supported_device_type) for x in tensor_args
            ):
                why_not_sparsity_fast_path = (
                    "some Tensor argument's device is not one of "
                    f"{_supported_device_type}"
                )
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(
                    src_mask, src_key_padding_mask, src
                )
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(
                x
                + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)
            )
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        is_causal: bool = False,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=is_causal,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way for their application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)

    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    __constants__ = ["norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``False``.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                forward and backward compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
            )
            x = x + self._mha_block(
                self.norm2(x),
                memory,
                memory_mask,
                memory_key_padding_mask,
                memory_is_causal,
            )
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory, memory_mask, memory_key_padding_mask, memory_is_causal
                )
            )
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        is_causal: bool = False,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        is_causal: bool = False,
    ) -> Tensor:
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.Module
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be relu/gelu, not {activation}")


def _detect_is_causal_mask(
    mask: Optional[Tensor],
    is_causal: Optional[bool] = None,
    size: Optional[int] = None,
) -> bool:
    """Return whether the given attention mask is causal.

    Warning:
        If ``is_causal`` is not ``None``, its value will be returned as is.
        If a user supplies an incorrect ``is_causal`` hint:

        * ``is_causal=False`` when the mask is in fact a causal attention mask
          may lead to reduced performance relative to what would be achievable
          with ``is_causal=True``;
        * ``is_causal=True`` when the mask is in fact not a causal attention mask
          may lead to incorrect and unpredictable execution - in some scenarios,
          a causal mask may be applied based on the hint, in other execution
          scenarios the specified mask may be used. The choice may not appear
          to be deterministic, in that a number of factors like alignment,
          hardware SKU, etc. influence the decision whether to use a mask or
          rely on the hint.

    If ``size`` is not None, check whether the mask is a causal mask of the provided size.
    Otherwise, check for any causal mask.
    """
    # Prevent type refinement
    make_causal = is_causal is True

    if is_causal is None and mask is not None:
        sz = size if size is not None else mask.size(-2)
        causal_comparison = _generate_square_subsequent_mask(
            sz, device=mask.device, dtype=mask.dtype
        )

        # Do not use `torch.equal` so we handle batched masks by
        # broadcasting the comparison.
        if mask.size() == causal_comparison.size():
            make_causal = bool((mask == causal_comparison).all())
        else:
            make_causal = False

    return make_causal