# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Source transformation for Llama export: replaces the fused multi-head attention
# (MHA) module with per-head single-head attention (SHA) modules.

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama.llama_transformer import Attention
from torch import nn


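# Rotary positional embedding applied to a single head's projection. The last
# dimension is split into even- and odd-indexed features (treated as the real and
# imaginary parts of a complex rotation), and the rotated halves are concatenated
# back along the last dimension rather than re-interleaved, so freqs_cos / freqs_sin
# are expected to carry head_dim // 2 values per position. Illustrative shapes
# (an assumption, not enforced by the code):
#   x:         (bsz, seqlen, head_dim)
#   freqs_cos: (seqlen, head_dim // 2)
#   freqs_sin: (seqlen, head_dim // 2)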
def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


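# Per-head KV cache: one (max_batch_size, max_seq_length, head_dim) buffer is
# registered for each KV head, so every head can be updated and read independently.
# update() writes k_val / v_val in place at input_pos along the sequence dimension
# via aten.index_put_ and returns the full caches.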
class KVCacheSHA(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()

        # a buffer per head
        cache_shape = (max_batch_size, max_seq_length, head_dim)
        for i in range(n_heads):
            self.register_buffer(
                f"past_k_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )
            self.register_buffer(
                f"past_v_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        cache_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_k = torch.ops.aten.index_put_(
            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
        )
        new_v = torch.ops.aten.index_put_(
            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
        )
        return new_k, new_v

    def get_cache(self, head_idx):
        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
            self, f"past_v_caches_{head_idx}"
        )


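# Scaled dot-product attention over lists of single-head tensors. Each query head i
# computes softmax(q @ k^T / sqrt(head_dim) + mask) @ v against the KV cache selected
# by cache_idx = i // n_rep (grouped-query attention: n_rep query heads share one
# KV head), and the per-head outputs are concatenated along the last dimension.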
class SDPASHA(torch.nn.Module):

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        n_rep: int,
        head_dim: int,
        dim: int,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.dim = dim
        self.kv_cache = KVCacheSHA(
            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
        )
        self.scale_factor = math.sqrt(head_dim)

    def forward(
        self,
        input_pos: torch.Tensor,
        qs: List[torch.Tensor],
        ks: List[torch.Tensor],
        vs: List[torch.Tensor],
        mask,
    ):

        transpose_ks = []
        for i in range(len(ks)):
            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
            transpose_ks.append(new_k.transpose(-2, -1).contiguous())

        output = []
        for i, q in enumerate(qs):
            cache_idx = i // self.n_rep
            _, v = self.kv_cache.get_cache(cache_idx)

            attn_mask = mask[input_pos]

            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
            attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            output.append(attn_weight @ v.contiguous())

        return torch.cat(output, dim=-1)


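# Single-head-attention (SHA) rewrite of the fused MHA module. The fused wq/wk/wv
# projection weights of the source Attention are sliced row-wise into per-head
# nn.Linear(dim, head_dim) modules; each head is rotated and cached independently,
# and the concatenated SHA outputs go through the shared output projection wo.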
class AttentionSHA(nn.Module):
    def __init__(self, attention_mha: nn.Module):
        super().__init__()
        if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not supported")

        self.n_heads = attention_mha.n_heads
        self.n_kv_heads = attention_mha.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.dim = attention_mha.dim
        self.max_batch_size = attention_mha.max_batch_size
        self.max_seq_len = attention_mha.max_seq_len
        self.head_dim = attention_mha.dim // self.n_heads
        self.SDPA = SDPASHA(
            self.max_batch_size,
            self.max_seq_len,
            self.n_heads,
            self.n_rep,
            self.head_dim,
            self.dim,
        )
        self.wq = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        for i in range(self.n_heads):
            self.wq[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        self.wo = attention_mha.wo

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        # QKV
        q = [wq(x) for wq in self.wq]
        k = [wk(x) for wk in self.wk]
        v = [wv(x) for wv in self.wv]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

        output = self.SDPA(input_pos, q, k, v, self.mask)
        return self.wo(output)


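# Recursively swaps every llama_transformer.Attention child of `module` for an
# AttentionSHA wrapper, mutating the module in place and returning it. A hedged
# usage sketch, assuming `model` is a Llama Transformer whose layers were built
# with use_kv_cache enabled:
#
#   model = replace_attention_to_attention_sha(model)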
def replace_attention_to_attention_sha(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(
                module,
                name,
                AttentionSHA(child),
            )
        else:
            replace_attention_to_attention_sha(child)
    return module

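
# Minimal smoke test (an illustrative sketch, not part of the original module); the
# shapes below are assumptions chosen only for demonstration. It exercises the rotary
# helper and the per-head KV cache without constructing a full Attention module.
if __name__ == "__main__":
    bsz, seqlen, head_dim, n_kv_heads, max_seq = 1, 1, 8, 2, 16

    x = torch.randn(bsz, seqlen, head_dim)
    freqs_cos = torch.randn(seqlen, head_dim // 2)
    freqs_sin = torch.randn(seqlen, head_dim // 2)
    # Rotation preserves the input shape; real/imag halves come back concatenated.
    assert apply_rotary_emb_single(x, freqs_cos, freqs_sin).shape == x.shape

    cache = KVCacheSHA(bsz, max_seq, n_kv_heads, head_dim)
    input_pos = torch.tensor([0])
    new_k, _ = cache.update(
        input_pos,
        torch.randn(bsz, seqlen, head_dim),
        torch.randn(bsz, seqlen, head_dim),
        cache_idx=0,
    )
    # update() writes at input_pos in place and returns the full cache buffer.
    assert new_k.shape == (bsz, max_seq, head_dim)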