# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn


class BertSelfAttention(nn.Module):
    """BERT-style multi-head self-attention with optional relative position
    embeddings ("relative_key" or "relative_key_query") and support for
    cached key/value states passed via ``past_key_value``."""

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            # One embedding per possible relative distance in [-(L-1), L-1].
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def transpose_for_scores(self, x):
        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        if past_key_value is not None:
            # Prepend cached keys/values along the sequence dimension.
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if self.position_embedding_type is not None:
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances into [0, 2 * max_position_embeddings - 2] for the lookup.
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=q.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", k, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, v)

        # (batch, num_heads, seq_len, head_size) -> (batch, seq_len, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
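# Minimal usage sketch (an illustration, not part of the module's public API):
# the hyperparameters below (hidden_size=768, 12 heads, dropout 0.1,
# "relative_key" position embeddings, max_position_embeddings=512) are assumed
# values chosen for demonstration. It runs a random batch through the layer to
# show that the output keeps the input shape.
if __name__ == "__main__":
    batch_size, seq_length, hidden_size = 2, 16, 768
    attention = BertSelfAttention(
        hidden_size=hidden_size,
        num_attention_heads=12,
        attention_probs_dropout_prob=0.1,
        position_embedding_type="relative_key",
        max_position_embeddings=512,
    )
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)
    context = attention(hidden_states)
    print(context.shape)  # torch.Size([2, 16, 768])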