# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

import torch
import torch.nn as nn

from executorch.examples.models.llama.llama_transformer import (
    FeedForward,
    ModelArgs,
    precompute_freqs_cis,
)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
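    # x is split into even-indexed (x_r) and odd-indexed (x_i) features and
    # rotated by the precomputed cos/sin tables. The rotated halves are
    # concatenated back-to-back rather than re-interleaved, so the output
    # feature order differs from the classic interleaved RoPE layout; this
    # stays consistent because both q and k pass through this same function.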
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


class LlamaAttention(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=False):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.head_dim = config.dim // config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.num_key_value_groups = config.n_heads // self.n_kv_heads
        self.max_seq_len = config.max_seq_len
        self.output_new_cache_only = output_new_cache_only

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.attn_softmax = torch.nn.Softmax(dim=-1)

        self.scale = float(self.head_dim) ** 0.5

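    # prepare_sha() rewrites this module for single-head attention (SHA): the
    # fused wq/wk/wv projections are split into one nn.Linear per head and
    # forward() is swapped for forward_sha(). Presumably the per-head form
    # lowers more cleanly to the Qualcomm backend than one fused multi-head
    # matmul; the copied weight slices keep the two paths numerically equivalent.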
    def prepare_sha(self):
        self.wq_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        self.forward_mha = self.forward
        self.forward = self.forward_sha

        for i in range(self.n_heads):
            self.wq_sha[i].weight.data.copy_(
                self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk_sha[i].weight.data.copy_(
                self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv_sha[i].weight.data.copy_(
                self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )

    def forward_sha(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
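        # Expected shapes (matching get_example_inputs below): hidden_states is
        # (bsz, seqlen, dim); each k_caches[i] is (bsz, head_dim, cache_len) and
        # each v_caches[i] is (bsz, cache_len, head_dim). The k cache is kept
        # pre-transposed so new keys can be concatenated along the last dim.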
        q = [wq_sha(hidden_states) for wq_sha in self.wq_sha]
        k = [wk_sha(hidden_states) for wk_sha in self.wk_sha]
        v = [wv_sha(hidden_states) for wv_sha in self.wv_sha]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)

        output_y = []
        kh, vh = [], []
        for i, _ in enumerate(k_caches):
            kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
            vh.append(torch.cat([v_caches[i], v[i]], dim=1))

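        # Grouped-query attention: every num_key_value_groups query heads share
        # a single k/v head, selected via the integer-divided cache_idx below.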
        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn / self.scale + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = self.wo(y)
        return y, k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
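        # Fused multi-head path; prepare_sha() swaps this out for forward_sha(),
        # which computes the same attention with per-head projections.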
        bsz, seqlen, _ = hidden_states.shape

        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)

        output_kh, output_vh, output_y = [], [], []
        kh, vh = [], []
        for i, _ in enumerate(k_caches):
            kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
            vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))

        for i in range(self.n_heads):
            cache_idx = i // self.num_key_value_groups

            attn = q[:, :, i, :] @ kh[cache_idx]
            attn = attn / self.scale + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

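        # Return either only the newly produced k/v slices or the fully
        # concatenated caches, depending on output_new_cache_only (the static
        # runner presumably appends the new slices to its own cache buffers).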
        for i in range(len(k_caches)):
            if self.output_new_cache_only:
                output_kh.append(k[:, i, :, :])
                output_vh.append(v[:, :, i, :])
            else:
                output_kh.append(kh[i])
                output_vh.append(vh[i])

        y = torch.concat(output_y, dim=-1)
        y = self.wo(y)

        return y, output_kh, output_vh


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=False):
        super().__init__()
        self.dim = config.dim
        self.attention = LlamaAttention(
            config=config, output_new_cache_only=output_new_cache_only
        )
        self.feed_forward = FeedForward(config)
        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
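        # Pre-norm residual block:
        #   h = x + Attention(RMSNorm(x)); out = h + FeedForward(RMSNorm(h))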
        h, k_cache, v_cache = self.attention(
            hidden_states=self.attention_norm(x),
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
            atten_mask=atten_mask,
            k_caches=k_caches,
            v_caches=v_caches,
        )
        h = x + h
        output = h + self.feed_forward(self.ffn_norm(h))
        return output, k_cache, v_cache


class LlamaModel(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=True):
        super().__init__()
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.max_batch_size = config.max_batch_size
        self.max_seq_len = config.max_seq_len
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_layers = config.n_layers
        self.vocab_size = config.vocab_size
        self.rope_freq_base = config.rope_freq_base
        self.output_new_cache_only = output_new_cache_only

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config, self.output_new_cache_only)
                for _ in range(config.n_layers)
            ]
        )
        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        freqs_cos, freqs_sin = precompute_freqs_cis(
            config.dim // config.n_heads,
            config.max_seq_len,
            config.rope_freq_base,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
        atten_mask: torch.Tensor,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
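        # KV caches arrive flattened in *args: the first n_layers * n_kv_heads
        # tensors are k caches (grouped per layer), followed by the same number
        # of v caches; offset_k / offset_v below index into this flat list.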
        output_k_cache = []
        output_v_cache = []
        # The rotary tables gathered here are identical for every batch entry,
        # so keep only the slice for the first one.
        freqs_cos = self.freqs_cos[input_pos][0]
        freqs_sin = self.freqs_sin[input_pos][0]

        hidden_states = self.tok_embeddings(tokens)
        for ind, decoder_layer in enumerate(self.layers):
            offset_k = ind * self.n_kv_heads
            offset_v = self.n_layers * self.n_kv_heads + offset_k
            k_caches = args[offset_k : offset_k + self.n_kv_heads]
            v_caches = args[offset_v : offset_v + self.n_kv_heads]
            hidden_states, k, v = decoder_layer(
                hidden_states,
                freqs_cos=freqs_cos,
                freqs_sin=freqs_sin,
                atten_mask=atten_mask,
                k_caches=k_caches,
                v_caches=v_caches,
            )
            output_k_cache.extend(k)
            output_v_cache.extend(v)

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits, output_k_cache, output_v_cache

    def get_example_inputs(self):
        tokens = torch.randint(
            self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
        )
        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
        k_cache, v_cache = [], []
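        # Additive attention mask over the static sequence: masked slots hold a
        # large negative value (-255.0 rather than -inf, presumably to stay
        # quantization-friendly) and only the last slot, which corresponds to
        # the current token, is left open at 0.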
        atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
        atten_mask[:, -1] = 0
        for _ in range(self.n_layers):
            for _ in range(self.n_kv_heads):
                # The k cache is stored pre-transposed as (batch, head_dim, seq)
                # so no transpose is needed at runtime.
                k_cache.append(
                    torch.zeros(
                        self.max_batch_size,
                        self.head_dim,
                        self.max_seq_len - 1,
                    )
                )
                v_cache.append(
                    torch.zeros(
                        self.max_batch_size,
                        self.max_seq_len - 1,
                        self.head_dim,
                    )
                )
        return (
            tokens,
            pos_ids,
            atten_mask,
            k_cache,
            v_cache,
        )
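
    # A minimal usage sketch (hypothetical): forward() expects the caches
    # splatted as positional arguments, e.g.
    #   tokens, pos_ids, atten_mask, k_caches, v_caches = model.get_example_inputs()
    #   logits, new_ks, new_vs = model(tokens, pos_ids, atten_mask, *k_caches, *v_caches)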

    def get_metadata(self):
        # TODO: modify this when enabling LLAMA 7B
        return {
            "get_bos_id": 1,
            "get_eos_id": 2,
            "get_dim": self.dim,
            "get_head_dim": self.dim // self.n_heads,
            "get_max_batch_size": self.max_batch_size,
            "get_max_seq_len": self.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": self.n_kv_heads,
            "get_n_layers": self.n_layers,
            "get_vocab_size": self.vocab_size,
        }