# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

import torch
import torch.nn as nn

from executorch.examples.models.llama.llama_transformer import (
    FeedForward,
    ModelArgs,
    precompute_freqs_cis,
)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


class LlamaAttention(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=False):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.head_dim = config.dim // config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.num_key_value_groups = config.n_heads // self.n_kv_heads
        self.max_seq_len = config.max_seq_len
        self.output_new_cache_only = output_new_cache_only

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.attn_softmax = torch.nn.Softmax(dim=-1)

        self.scale = float(self.head_dim) ** 0.5

    def prepare_sha(self):
        self.wq_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv_sha = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        self.forward_mha = self.forward
        self.forward = self.forward_sha

        for i in range(self.n_heads):
            self.wq_sha[i].weight.data.copy_(
                self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk_sha[i].weight.data.copy_(
                self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv_sha[i].weight.data.copy_(
                self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )

    def forward_sha(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = [wq_sha(hidden_states) for wq_sha in self.wq_sha]
        k = [wk_sha(hidden_states) for wk_sha in self.wk_sha]
        v = [wv_sha(hidden_states) for wv_sha in self.wv_sha]
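        # RoPE is applied to each projected head individually; keys are then
        # transposed to (batch, head_dim, seq) so they can be concatenated with
        # the pre-transposed key cache. Query heads share a K/V cache in groups
        # of `num_key_value_groups` (grouped-query attention).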
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)

        output_y = []
        kh, vh = [], []
        for i, _ in enumerate(k_caches):
            kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
            vh.append(torch.cat([v_caches[i], v[i]], dim=1))

        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn / self.scale + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = self.wo(y)
        return y, k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        bsz, seqlen, _ = hidden_states.shape

        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)

        output_kh, output_vh, output_y = [], [], []
        kh, vh = [], []
        for i, _ in enumerate(k_caches):
            kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
            vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))

        for i in range(self.n_heads):
            cache_idx = i // self.num_key_value_groups

            attn = q[:, :, i, :] @ kh[cache_idx]
            attn = attn / self.scale + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        for i in range(len(k_caches)):
            if self.output_new_cache_only:
                output_kh.append(k[:, i, :, :])
                output_vh.append(v[:, :, i, :])
            else:
                output_kh.append(kh[i])
                output_vh.append(vh[i])

        y = torch.concat(output_y, dim=-1)
        y = self.wo(y)

        return y, output_kh, output_vh
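

# Hedged usage sketch (an addition, not part of the original module): switching
# an attention block to the per-head ("SHA") layout is done by calling
# `prepare_sha()`, which copies the fused Q/K/V weights into per-head nn.Linear
# modules and rebinds `forward` to `forward_sha`; the original multi-head path
# stays reachable via `forward_mha`. The helper name below is illustrative only.
def _convert_attention_to_sha(model: nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.prepare_sha()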


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=False):
        super().__init__()
        self.dim = config.dim
        self.attention = LlamaAttention(
            config=config, output_new_cache_only=output_new_cache_only
        )
        self.feed_forward = FeedForward(config)
        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h, k_cache, v_cache = self.attention(
            hidden_states=self.attention_norm(x),
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
            atten_mask=atten_mask,
            k_caches=k_caches,
            v_caches=v_caches,
        )
        h = x + h
        output = h + self.feed_forward(self.ffn_norm(h))
        return output, k_cache, v_cache


class LlamaModel(nn.Module):
    def __init__(self, config: ModelArgs, output_new_cache_only=True):
        super().__init__()
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.max_batch_size = config.max_batch_size
        self.max_seq_len = config.max_seq_len
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_layers = config.n_layers
        self.vocab_size = config.vocab_size
        self.rope_freq_base = config.rope_freq_base
        self.output_new_cache_only = output_new_cache_only

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config, self.output_new_cache_only)
                for _ in range(config.n_layers)
            ]
        )
        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        freqs_cos, freqs_sin = precompute_freqs_cis(
            config.dim // config.n_heads,
            config.max_seq_len,
            config.rope_freq_base,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
        atten_mask: torch.Tensor,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        output_k_cache = []
        output_v_cache = []
        # The following tensors should be invariant across batches.
        freqs_cos = self.freqs_cos[input_pos][0]
        freqs_sin = self.freqs_sin[input_pos][0]

        hidden_states = self.tok_embeddings(tokens)
        for ind, decoder_layer in enumerate(self.layers):
            # args holds the flattened KV caches: all key caches
            # (n_layers * n_kv_heads of them) first, then all value caches.
            offset_k = ind * self.n_kv_heads
            offset_v = self.n_layers * self.n_kv_heads + offset_k
            k_caches = args[offset_k : offset_k + self.n_kv_heads]
            v_caches = args[offset_v : offset_v + self.n_kv_heads]
            hidden_states, k, v = decoder_layer(
                hidden_states,
                freqs_cos=freqs_cos,
                freqs_sin=freqs_sin,
                atten_mask=atten_mask,
                k_caches=k_caches,
                v_caches=v_caches,
            )
            output_k_cache.extend(k)
            output_v_cache.extend(v)

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits, output_k_cache, output_v_cache

    def get_example_inputs(self):
        tokens = torch.randint(
            self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
        )
        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
        k_cache, v_cache = [], []
        atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
        atten_mask[:, -1] = 0
        for _ in range(self.n_layers):
            for _ in range(self.n_kv_heads):
                # The key cache is stored pre-transposed to
                # (batch, head_dim, seq) to reduce transposes at runtime.
                k_cache.append(
                    torch.zeros(
                        self.max_batch_size,
                        self.head_dim,
                        self.max_seq_len - 1,
                    )
                )
                v_cache.append(
                    torch.zeros(
                        self.max_batch_size,
                        self.max_seq_len - 1,
                        self.head_dim,
                    )
                )
        return (
            tokens,
            pos_ids,
            atten_mask,
            k_cache,
            v_cache,
        )

    def get_metadata(self):
        # TODO: modify this when enabling LLAMA 7B
        return {
            "get_bos_id": 1,
            "get_eos_id": 2,
            "get_dim": self.dim,
            "get_head_dim": self.dim // self.n_heads,
            "get_max_batch_size": self.max_batch_size,
            "get_max_seq_len": self.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": self.n_kv_heads,
            "get_n_layers": self.n_layers,
            "get_vocab_size": self.vocab_size,
        }
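

# Minimal smoke-test sketch (an addition, not part of the original module). It
# assumes `model` is a LlamaModel built from a fully populated ModelArgs; the
# helper name is illustrative only. It shows how the per-layer, per-head KV
# caches returned by `get_example_inputs` are flattened into positional
# arguments: all key caches first, then all value caches, matching the offset
# arithmetic in `LlamaModel.forward`.
def _decode_single_token_example(model: LlamaModel) -> torch.Tensor:
    tokens, pos_ids, atten_mask, k_cache, v_cache = model.get_example_inputs()
    logits, _new_k, _new_v = model(tokens, pos_ids, atten_mask, *k_cache, *v_cache)
    return logits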