# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn


class BertSelfAttention(nn.Module):
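    """BERT-style multi-head self-attention with optional relative position embeddings.

    With ``position_embedding_type`` set to ``"relative_key"`` or ``"relative_key_query"``,
    a learned embedding of the query/key distance is added to the attention scores.
    """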
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type

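        # Relative position embeddings need a table covering every possible
        # query/key distance, i.e. 2 * max_position_embeddings - 1 entries.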
        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def transpose_for_scores(self, x):
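        """Split the last dimension into heads: [b, l, h*d] -> [b, h, l, d]."""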
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
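        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: tensor of shape ``[batch, seq_len, hidden_size]``.
            past_key_value: optional pair of cached key/value tensors of shape
                ``[batch, num_heads, past_len, head_size]`` that are prepended
                along the sequence dimension.

        Returns:
            Tensor of shape ``[batch, seq_len, all_head_size]``.
        """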
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

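        # Prepend any cached key/value states along the sequence dimension (dim 2).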
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

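        # Relative position scoring: look up an embedding for each query/key
        # distance (shifted to be non-negative) and fold it into the attention
        # scores via einsum.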
        if self.position_embedding_type is not None:
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=q.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", k, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, v)

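        # Merge the heads back into a single hidden dimension: [b, h, l, d] -> [b, l, h*d].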
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
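

# Minimal usage sketch (illustrative only, not part of the original test file):
# builds a small BertSelfAttention with assumed hyperparameters and runs one
# forward pass to show the expected input/output shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = BertSelfAttention(
        hidden_size=64,  # assumed; must be divisible by num_attention_heads
        num_attention_heads=4,
        attention_probs_dropout_prob=0.1,
        position_embedding_type="relative_key",
        max_position_embeddings=128,
    )
    attn.eval()  # disable dropout for a deterministic pass
    hidden_states = torch.randn(2, 16, 64)  # [batch, seq_len, hidden_size]
    out = attn(hidden_states)
    print(out.shape)  # torch.Size([2, 16, 64])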