# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting Llama2 to flatbuffer

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama.llama_transformer import Attention
from torch import nn


# Apply rotary position embedding to a single q or k tensor. The even/odd
# channel halves are rotated and then concatenated (not re-interleaved).
def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


# KV cache that keeps one (max_batch_size, max_seq_length, head_dim) buffer
# per KV head, so each head's cache can be updated and read independently.
class KVCacheSHA(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()

        # a buffer per head
        cache_shape = (max_batch_size, max_seq_length, head_dim)
        for i in range(n_heads):
            self.register_buffer(
                f"past_k_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )
            self.register_buffer(
                f"past_v_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        cache_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_k = torch.ops.aten.index_put_(
            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
        )
        new_v = torch.ops.aten.index_put_(
            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
        )
        return new_k, new_v

    def get_cache(self, head_idx):
        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
            self, f"past_v_caches_{head_idx}"
        )


# Scaled dot-product attention over lists of per-head q/k/v tensors, updating
# and reading the per-head KV cache along the way.
class SDPASHA(torch.nn.Module):

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        n_rep: int,
        head_dim: int,
        dim: int,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.dim = dim
        self.kv_cache = KVCacheSHA(
            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
        )
        self.scale_factor = math.sqrt(head_dim)

    def forward(
        self,
        input_pos: torch.Tensor,
        qs: List[torch.Tensor],
        ks: List[torch.Tensor],
        vs: List[torch.Tensor],
        mask,
    ):

        transpose_ks = []
        for i in range(len(ks)):
            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
            transpose_ks.append(new_k.transpose(-2, -1).contiguous())

        output = []
        for i, q in enumerate(qs):
            # Query head i attends to KV head i // n_rep (grouped-query attention).
            cache_idx = i // self.n_rep
            _, v = self.kv_cache.get_cache(cache_idx)

            attn_mask = mask[input_pos]

            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
            attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            output.append(attn_weight @ v.contiguous())

        return torch.cat(output, dim=-1)


# Single-head-attention (SHA) rewrite of the Llama `Attention` module: the
# fused wq/wk/wv projections are split into one nn.Linear per head and the
# KV cache is kept as per-head buffers.
class AttentionSHA(nn.Module):
    def __init__(self, attention_mha: nn.Module):
        super().__init__()
        if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not supported")

        self.n_heads = attention_mha.n_heads
        self.n_kv_heads = attention_mha.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.dim = attention_mha.dim
        self.max_batch_size = attention_mha.max_batch_size
        self.max_seq_len = attention_mha.max_seq_len
        self.head_dim = attention_mha.dim // self.n_heads
        self.SDPA = SDPASHA(
            self.max_batch_size,
            self.max_seq_len,
            self.n_heads,
            self.n_rep,
            self.head_dim,
            self.dim,
        )
        self.wq = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        # Copy each head's slice of the fused MHA projection weights into the
        # corresponding per-head linear layer.
        for i in range(self.n_heads):
            self.wq[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv[i].weight.data.copy_(
                # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
                #  attribute `weight`.
                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        self.wo = attention_mha.wo

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        # QKV
        q = [wq(x) for wq in self.wq]
        k = [wk(x) for wk in self.wk]
        v = [wv(x) for wv in self.wv]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

        output = self.SDPA(input_pos, q, k, v, self.mask)
        return self.wo(output)


# Recursively swap every `Attention` submodule for its AttentionSHA equivalent.
def replace_attention_to_attention_sha(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(
                module,
                name,
                AttentionSHA(child),
            )
        else:
            replace_attention_to_attention_sha(child)
    return module
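

# ---------------------------------------------------------------------------
# Minimal, hedged usage sketch (not part of the original module). The
# `__main__` block below only exercises `apply_rotary_emb_single`; the toy
# head_dim, sequence length, and the 10000.0 rotary base are assumptions
# chosen to mirror the usual Llama convention, not values taken from this
# file. Converting a real model would instead call
# `replace_attention_to_attention_sha(model)` on a llama_transformer model
# built with use_kv_cache=True before export.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    head_dim, seq_len = 64, 8  # assumed toy sizes

    # Llama-style rotary frequencies for half the head dimension.
    theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), theta)  # (seq_len, head_dim // 2)
    freqs_cos, freqs_sin = torch.cos(freqs), torch.sin(freqs)

    x = torch.randn(1, seq_len, head_dim)
    x_rot = apply_rotary_emb_single(x, freqs_cos, freqs_sin)

    # Rotary embedding rotates (even, odd) channel pairs, so the L2 norm of
    # each position vector is unchanged.
    assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)
    print("rotary sanity check passed:", tuple(x_rot.shape))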