# flake8: noqa: E266, C417, B950
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    # Round n up to the next multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    num_experts: int = 8
    num_activated_experts: int = 2

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [
            config
            for config in transformer_configs
            if config in str(name).upper() or config in str(name)
        ]
        assert len(config) == 1, name
        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "Mixtral-8x7B-v0.1": dict(
        block_size=32768,
        n_layer=16,
        n_head=32,
        n_local_heads=8,
        dim=4096,
        intermediate_size=14336,
        rope_base=1000000.0,
        num_experts=8,
        num_activated_experts=2,
    ),
}


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        # Allocate one KV cache per layer and precompute the RoPE table and causal mask.
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
            )

        self.freqs_cis = precompute_freqs_cis(
            self.config.block_size,
            self.config.dim // self.config.n_head,
            self.config.rope_base,
        )
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for layer in self.layers:
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.block_sparse_moe = MOEFeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.block_sparse_moe(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Checkpoints may store separate wq/wk/wv weights; fuse them into wqkv.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand the grouped KV heads to match the number of query heads (GQA).
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class ConditionalFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Parameter(
            torch.empty(config.num_experts, config.intermediate_size, config.dim)
        )
        self.w2 = nn.Parameter(
            torch.empty(config.num_experts, config.dim, config.intermediate_size)
        )
        self.w3 = nn.Parameter(
            torch.empty(config.num_experts, config.intermediate_size, config.dim)
        )

    def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
        # x: [T, D], expert_indices: [T, A], I = intermediate_size
        w1_weights = self.w1[expert_indices]  # [T, A, I, D]
        w3_weights = self.w3[expert_indices]  # [T, A, I, D]
        w2_weights = self.w2[expert_indices]  # [T, A, D, I]
        x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
        x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
        expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
        return expert_outs


class MOEFeedForward(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForward(config)
        self.dim = config.dim
        self.num_activated_experts = config.num_activated_experts

    def forward(self, x: Tensor) -> Tensor:
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = embedding dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(
            expert_weights, self.num_activated_experts, dim=-1
        )  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum("tai,ta -> ti", expert_outs, expert_weights)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # Store (cos, sin) pairs instead of complex numbers.
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # Treat the last dim of x as (real, imag) pairs and rotate by the cached angles.
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
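

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module, which is normally
# driven by a separate checkpoint-loading / generation script): it shows how
# setup_caches(), prefill, and single-token decode fit together. The tiny
# ModelArgs values below are illustrative only, not the "Mixtral-8x7B-v0.1"
# configuration, so this runs on CPU without real weights.
if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs(
        block_size=128,
        vocab_size=256,
        n_layer=2,
        n_head=4,
        n_local_heads=2,
        dim=64,
        intermediate_size=128,
        num_experts=4,
        num_activated_experts=2,
    )
    # KVCache and precompute_freqs_cis default to bfloat16, so run the whole
    # model in bfloat16 to keep dtypes consistent.
    model = Transformer(args).to(torch.bfloat16).eval()
    model.setup_caches(max_batch_size=1, max_seq_length=32)

    # The expert weights are allocated with torch.empty (real runs load them
    # from a checkpoint), so give them values for this smoke test.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "cond_ffn" in name:
                param.normal_(0.0, 0.02)

    with torch.no_grad():
        # Prefill: feed the whole prompt and record its positions in the cache.
        prompt = torch.randint(0, args.vocab_size, (1, 8))
        input_pos = torch.arange(8)
        logits = model(prompt, input_pos)  # [1, 8, vocab_size]

        # Decode one step: a single token at the next position reuses the cache.
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # [1, 1]
        next_logits = model(next_token, torch.tensor([8]))  # [1, 1, vocab_size]

    print(logits.shape, next_logits.shape)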