# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """
    NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except
    that SDPA is factored out so that it can be swapped for more
    efficient ExecuTorch-defined SDPA ops.

    Multi-headed attention layer with support for grouped query
    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer
    key/value heads than query heads by grouping n query heads for each
    key and value head. Multi-Query Attention is an extreme
    version where we have a single key and value head shared by all
    query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4

    (credit for the documentation:
    `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).


    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
        ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
                MHA                    GQA                   MQA
          n_kv_heads=4          n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
            before updating from kv_cache. This means it will only support token-wide normalization and not
            batch- or sequence-wide normalization.
        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        is_causal (bool): sets the default mask to causal when no mask is provided
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
            Default value is 0.0.

    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
        ValueError: If q_norm is defined without k_norm or vice versa
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: Optional[nn.Module] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(
                f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0"
            )

        if bool(q_norm) ^ bool(k_norm):
            raise ValueError("q and k norm must be set together")

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.pos_embeddings = pos_embeddings

        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()
        self._sdpa = SDPA(
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            attn_dropout=self.attn_dropout if self.training else 0.0,
            is_causal=self.is_causal,
            attention_fn=self._attention_call,
            kv_cache=self.kv_cache,
        )

        # This flag indicates whether to update the KV cache during forward
        # passes. When disabled, the cache can be set up while standard
        # forward passes are still performed.
        self.cache_enabled = False

    def setup_cache(
        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
    ) -> None:
        """Setup key value caches for attention calculation. If called
        after kv_cache is already set up, this will be skipped.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            max_seq_len (int): maximum sequence length the model will be run with.
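
        Example:
            A minimal sketch; ``attn`` stands for an instance of this module, and
            the batch size, dtype, and sequence length below are arbitrary::

                attn.setup_cache(batch_size=1, dtype=torch.float32, max_seq_len=128)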
        """
        # Don't overwrite user-defined kv_cache from init
        if self.kv_cache is not None:
            logger.warning(
                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
            )
        else:
            self.kv_cache = InferenceKVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                transpose_cache=False,
            )
            self._sdpa.kv_cache = self.kv_cache
            self.cache_enabled = True

    def reset_cache(self):
        """Reset the key value caches."""
        if self.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        self.kv_cache.reset()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
                and before the softmax. Either:

                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
                or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``.
                A value of False means token ``i`` does not attend to token ``j``. If no mask
                is specified, a causal mask is used by default.

                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
                created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
                Default is None.
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If None, assume the index of the token is its position id. Default is None.

        Raises:
            ValueError: If no ``y`` input and ``kv_cache`` is not enabled.

        Returns:
            torch.Tensor: output tensor with attention applied

        Notation used for tensor shapes:
            - b: batch size
            - s_x: sequence length for x
            - s_y: sequence length for y
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim
        """
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads
        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

        # Apply positional embeddings
        if self.pos_embeddings is not None:
            q = self.pos_embeddings(q, input_pos=input_pos)

        # Normalize q
        if self.q_norm is not None:
            q = self.q_norm(q)

        def calculate_kv(y):
            # Update k and v shape, positional embeddings, and normalization
            s_y = y.shape[1]
            # k has shape [b, s_y, num_kv_heads * head_dim]
            # v has shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
            v = self.v_proj(y)

            # Apply positional embeddings
            # k: [b, s_y, n_kv, h_d]
            k = k.view(b, s_y, -1, self.head_dim)
            v = v.view(b, s_y, -1, self.head_dim)
            if self.pos_embeddings is not None:
                k = self.pos_embeddings(k, input_pos=input_pos)

            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)
            return k, v

        def true_fn(y):
            kv_cache = self.kv_cache.clone()
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        def false_fn(y):
            k, v = calculate_kv(y)
            kv_cache = self.kv_cache.clone()
            kv_cache.update(k, v)
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        # If kv_cache is None, we expect y to be provided
        if self.kv_cache is None:
            assert (
                y is not None
            ), "Must provide y input or use kv_cache to enable streaming decoding"
            k, v = calculate_kv(y)
        else:
            # Expect the k and v returned here to have the same size as self.kv_cache.
            # In eager mode, we expect this predicate to specialize. In export, it
            # becomes a SymBool, so it is not specialized.
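            # ``true_fn`` is taken when ``y`` is entirely NaN, which is used here as
            # a signal that there is no new k/v input for this step: the cached k/v
            # and cache position are returned unchanged. Otherwise ``false_fn``
            # computes fresh k/v from ``y`` and writes them into a cloned cache.
            # torch.cond keeps both branches traceable when the module is exported.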
            k, v, cache_pos = torch.cond(
                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
            )
            # Update key-value cache
            self.kv_cache.k_cache.copy_(k)
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

        output = self._sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)


class SDPA(nn.Module):
    """
    TorchTune's SDPA which can be optimized and can be swapped
    out for a more efficient implementation.
    """

    def __init__(
        self,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
        attn_dropout: float,
        is_causal: bool,
        attention_fn,
        kv_cache,
    ) -> None:
        super().__init__()
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_per_kv = self.num_heads // self.num_kv_heads
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache

    def forward(
        self,
        q: torch.Tensor,  # [b, s, n_h, h_d]
        k: torch.Tensor,  # [b, s, n_kv, h_d]
        v: torch.Tensor,  # [b, s, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[_MaskType] = None,
    ) -> torch.Tensor:
        # View + expand + reshape bring num_kv_heads to num_heads for k and v
        # to match q.
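        # q arrives as [b, s, n_h, h_d] and k/v as [b, s, n_kv, h_d]; the underlying
        # attention op expects the head dimension before the sequence dimension,
        # hence the transposes below.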

        # [bsz, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.num_heads != self.num_kv_heads:
            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Reshape the output to be the same shape as the input
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
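            # Recurse into child modules so nested attention layers are also swapped.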
            replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference-friendly version of MHA that
    separates out the inference-related parts for further optimization.
    """
    _replace_mha_with_inference_mha(module)
    return module
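

# Usage sketch (illustrative only, not executed as part of this module): one way to
# swap attention layers on a module tree that contains torchtune MultiHeadAttention
# layers and then enable KV caches before export. ``model`` and the sizes below are
# placeholders, not values defined by this file.
#
#   model = replace_mha_with_inference_mha(model)
#   for submodule in model.modules():
#       if isinstance(submodule, MultiHeadAttention):
#           submodule.setup_cache(batch_size=1, dtype=torch.float32, max_seq_len=128)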