# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """
    NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except
    that SDPA is factored out so that it can be swapped for more
    efficient ExecuTorch-defined SDPA ops.

    Multi-headed attention layer with support for grouped query
    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer
    key/value heads than query heads by grouping n query heads for each
    key and value head. Multi-Query Attention is an extreme
    version where we have a single key and value head shared by all
    query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4

    (credit for the documentation:
    `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).

    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
        n_kv_heads=4          n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
            before updating from kv_cache. This means it will only support token wide normalization and not
            batch or sequence wide normalization.
        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        is_causal (bool): sets the default mask to causal when no mask is provided
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
            Default value is 0.0.

    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
        ValueError: If ``q_norm`` is defined without ``k_norm`` or vice versa
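
    Example:
        A minimal construction sketch. The ``nn.Linear`` projections and the
        dimensions below are illustrative assumptions, not values used anywhere
        else in this module::

            >>> import torch
            >>> from torch import nn
            >>> embed_dim, num_heads, num_kv_heads = 256, 8, 2
            >>> head_dim = embed_dim // num_heads
            >>> attn = MultiHeadAttention(
            ...     embed_dim=embed_dim,
            ...     num_heads=num_heads,
            ...     num_kv_heads=num_kv_heads,
            ...     head_dim=head_dim,
            ...     q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            ...     k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            ...     v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            ...     output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
            ... )
            >>> x = torch.randn(1, 16, embed_dim)
            >>> attn(x, x).shape  # self-attention: y is x
            torch.Size([1, 16, 256])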
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: Optional[nn.Module] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")

        if bool(q_norm) ^ bool(k_norm):
            raise ValueError("q and k norm must be set together")

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.pos_embeddings = pos_embeddings

        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()
        self._sdpa = SDPA(
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            attn_dropout=self.attn_dropout if self.training else 0.0,
            is_causal=self.is_causal,
            attention_fn=self._attention_call,
            kv_cache=self.kv_cache,
        )

        # This flag indicates whether to update the kv-cache during forward
        # passes. When disabled, we can have the cache set up but still
        # perform normal forward passes.
        self.cache_enabled = False

    def setup_cache(
        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
    ) -> None:
        """Setup key value caches for attention calculation. If called
        after kv_cache is already setup, this will be skipped.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            max_seq_len (int): maximum sequence length model will be run with.
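
        Example:
            An illustrative sketch, assuming ``torch`` is imported and ``attn`` is the
            ``MultiHeadAttention`` module built in the class-level example; the batch
            size, dtype, and cache length below are arbitrary::

                >>> attn.setup_cache(batch_size=1, dtype=torch.float32, max_seq_len=128)
                >>> attn.cache_enabled
                True
                >>> attn.reset_cache()  # reset the key value caches between generations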
        """
        # Don't overwrite user defined kv_cache from init
        if self.kv_cache is not None:
            logger.warning(
                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
            )
        else:
            self.kv_cache = InferenceKVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                transpose_cache=False,
            )
            self._sdpa.kv_cache = self.kv_cache
            self.cache_enabled = True

    def reset_cache(self):
        """Reset the key value caches."""
        if self.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        self.kv_cache.reset()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
                and before the softmax. Either:

                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
                or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
                token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
                is used by default.

                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
                created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
                Default is None.
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Raises:
            ValueError: If no ``y`` input and ``kv_cache`` is not enabled.

        Returns:
            torch.Tensor: output tensor with attention applied

        Notation used for tensor shapes:
            - b: batch size
            - s_x: sequence length for x
            - s_y: sequence length for y
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim
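
        Example:
            A sketch of an explicit boolean causal mask in the ``[b x s x s]`` form
            described above (illustrative only; when ``is_causal=True`` and no kv-cache
            is set up, leaving ``mask=None`` has the same effect)::

                >>> import torch
                >>> s = 4
                >>> causal_mask = torch.tril(torch.ones(1, s, s, dtype=torch.bool))
                >>> causal_mask[0, 0]
                tensor([ True, False, False, False])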
        """
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads
        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

        # Apply positional embeddings
        if self.pos_embeddings is not None:
            q = self.pos_embeddings(q, input_pos=input_pos)

        # Normalize q
        if self.q_norm is not None:
            q = self.q_norm(q)

        def calculate_kv(y):
            # Update k and v shape, positional embeddings, and normalization
            s_y = y.shape[1]
            # k has shape [b, s_y, num_kv_heads * head_dim]
            # v has shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
            v = self.v_proj(y)

            # Apply positional embeddings
            # k: [b, s_y, n_kv, h_d]
            k = k.view(b, s_y, -1, self.head_dim)
            v = v.view(b, s_y, -1, self.head_dim)
            if self.pos_embeddings is not None:
                k = self.pos_embeddings(k, input_pos=input_pos)

            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)
            return k, v

        def true_fn(y):
            kv_cache = self.kv_cache.clone()
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        def false_fn(y):
            k, v = calculate_kv(y)
            kv_cache = self.kv_cache.clone()
            kv_cache.update(k, v)
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        # If kv cache is None, we expect y to be provided
        if self.kv_cache is None:
            assert (
                y is not None
            ), "Must provide y input or use kv_cache to enable streaming decoding"
            k, v = calculate_kv(y)
        else:
            # Expecting the k, v returned here to be the same size as self.kv_cache.
            # In eager, we expect this predicate to specialize. In export, this will
            # become a SymBool so it's not specialized.
            k, v, cache_pos = torch.cond(
                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
            )
            # Update key-value cache
            self.kv_cache.k_cache.copy_(k)
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

        output = self._sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)


class SDPA(nn.Module):
    """
    TorchTune's SDPA which can be optimized and can be swapped
    out for more efficient implementations.
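
    Example:
        A shape-only sketch of the grouped-query expansion performed in ``forward``:
        ``n_kv = 2`` key/value heads are repeated ``q_per_kv = 4`` times to match
        ``n_h = 8`` query heads (the dimensions are illustrative)::

            >>> import torch
            >>> k = torch.randn(1, 2, 16, 32)  # [b, n_kv, s, h_d]
            >>> k.unsqueeze(2).expand(-1, -1, 4, -1, -1).flatten(1, 2).shape
            torch.Size([1, 8, 16, 32])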
    """

    def __init__(
        self,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
        attn_dropout: float,
        is_causal: bool,
        attention_fn,
        kv_cache,
    ) -> None:
        super().__init__()
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_per_kv = self.num_heads // self.num_kv_heads
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache

    def forward(
        self,
        q: torch.Tensor,  # [b, s, n_h, h_d]
        k: torch.Tensor,  # [b, s, n_kv, h_d]
        v: torch.Tensor,  # [b, s, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[_MaskType] = None,
    ) -> torch.Tensor:
        # View + expand + reshape bring num_kv_heads to num_heads for k and v
        # to match q.

        # [bsz, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.num_heads != self.num_kv_heads:
            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Reshape the output to be the same shape as the input
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference friendly version of MHA that
    separates out the inference-related parts for further optimization.
    """
    _replace_mha_with_inference_mha(module)
    return module
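

# Example usage (an illustrative sketch, not part of this module's API): swap
# TorchTune's MultiHeadAttention modules in an existing model for this
# inference-friendly variant before setting up caches and exporting. The
# ``llama3_2_1b`` constructor below is an assumption for illustration; any
# module tree containing ``torchtune.modules.attention.MultiHeadAttention``
# can be rewritten the same way.
#
#     from torchtune.models.llama3_2 import llama3_2_1b
#
#     model = llama3_2_1b()
#     model = replace_mha_with_inference_mha(model)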