# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """
    NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except
    that SDPA is factored out so that it can be swapped for more
    efficient ExecuTorch-defined SDPA ops.

    Multi-headed attention layer with support for grouped query
    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer
    key/value heads than query heads by grouping n query heads for each
    key and value head. Multi-Query Attention is an extreme
    version where we have a single key and value head shared by all
    query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4

    (credit for the documentation:
    `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).


    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
        n_kv_heads=4           n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
            before updating from kv_cache. This means it will only support token wide normalization and not
            batch or sequence wide normalization.
        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        is_causal (bool): sets the default mask to causal when no mask is provided
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
            Default value is 0.0.

    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
        ValueError: if q_norm is defined without k_norm or vice versa
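
    Example:
        A minimal usage sketch (illustrative; assumes plain ``nn.Linear`` projections
        and no positional embeddings, ``q_norm``/``k_norm``, or KV cache)::

            import torch
            from torch import nn

            embed_dim, num_heads, num_kv_heads = 256, 8, 2
            head_dim = embed_dim // num_heads
            attn = MultiHeadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
                k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
            )
            x = torch.randn(2, 16, embed_dim)
            out = attn(x, x)  # self-attention: y is the same tensor as x
            # out.shape == (2, 16, embed_dim)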
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: Optional[nn.Module] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")

        if bool(q_norm) ^ bool(k_norm):
            raise ValueError("q and k norm must be set together")

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.pos_embeddings = pos_embeddings

        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()
        self._sdpa = SDPA(
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            attn_dropout=self.attn_dropout if self.training else 0.0,
            is_causal=self.is_causal,
            attention_fn=self._attention_call,
            kv_cache=self.kv_cache,
        )

        # this flag indicates whether to update the kv-cache during forward
        # passes. when disabled, we can have the cache setup but still
        # perform normal forward passes
        self.cache_enabled = False

    def setup_cache(
        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
    ) -> None:
        """Set up key value caches for attention calculation. If called
        after kv_cache is already set up, this will be skipped.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            max_seq_len (int): maximum sequence length model will be run with.
        """
        # Don't overwrite user defined kv_cache from init
        if self.kv_cache is not None:
            logger.warning(
                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
            )
        else:
            self.kv_cache = InferenceKVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                transpose_cache=False,
            )
            self._sdpa.kv_cache = self.kv_cache
            self.cache_enabled = True

    def reset_cache(self):
        """Reset the key value caches."""
        if self.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        self.kv_cache.reset()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
                and before the softmax. Either:

                A boolean tensor with shape ``[b x s x s]``, or ``[b x s x self.encoder_max_cache_seq_len]``
                if using KV-caching with encoder/decoder layers.
                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
                token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
                is used by default.

                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
                created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
                Default is None.
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If None, assume the index of the token is its position id. Default is None.

        Raises:
            ValueError: If no ``y`` input and ``kv_cache`` is not enabled.

        Returns:
            torch.Tensor: output tensor with attention applied

        Notation used for tensor shapes:
            - b: batch size
            - s_x: sequence length for x
            - s_y: sequence length for y
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim
        """
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads
        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

        # Apply positional embeddings
        if self.pos_embeddings is not None:
            q = self.pos_embeddings(q, input_pos=input_pos)

        # Normalize q
        if self.q_norm is not None:
            q = self.q_norm(q)

        def calculate_kv(y):
            # Update k and v shape, positional embeddings, and normalization
            s_y = y.shape[1]
            # k has shape [b, s_y, num_kv_heads * head_dim]
            # v has shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
            v = self.v_proj(y)

            # Apply positional embeddings
            # k: [b, s_y, n_kv, h_d]
            k = k.view(b, s_y, -1, self.head_dim)
            v = v.view(b, s_y, -1, self.head_dim)
            if self.pos_embeddings is not None:
                k = self.pos_embeddings(k, input_pos=input_pos)

            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)
            return k, v

        def true_fn(y):
            kv_cache = self.kv_cache.clone()
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        def false_fn(y):
            k, v = calculate_kv(y)
            kv_cache = self.kv_cache.clone()
            kv_cache.update(k, v)
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        # If kv cache is None, we expect y to be provided
        if self.kv_cache is None:
            assert (
                y is not None
            ), "Must provide y input or use kv_cache to enable streaming decoding"
            k, v = calculate_kv(y)
        else:
            # Expect the k, v returned here to be the same size as self.kv_cache.
            # In eager, we expect this predicate to specialize. In export, this will
            # become a SymBool so it's not specialized.
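            # An all-NaN ``y`` acts as a sentinel for "no new key/value input":
            # true_fn reuses the cached k/v unchanged, while false_fn projects
            # ``y`` and writes it into a clone of the cache. Routing both paths
            # through torch.cond keeps both branches in the exported graph.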
            k, v, cache_pos = torch.cond(
                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
            )
            # Update key-value cache
            self.kv_cache.k_cache.copy_(k)
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

        output = self._sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)


class SDPA(nn.Module):
    """
    TorchTune's SDPA which can be optimized and can be swapped
    out for a more efficient implementation.
    """

    def __init__(
        self,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
        attn_dropout: float,
        is_causal: bool,
        attention_fn,
        kv_cache,
    ) -> None:
        super().__init__()
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_per_kv = self.num_heads // self.num_kv_heads
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache

    def forward(
        self,
        q: torch.Tensor,  # [b, s, n_h, h_d]
        k: torch.Tensor,  # [b, s, n_kv, h_d]
        v: torch.Tensor,  # [b, s, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[_MaskType] = None,
    ) -> torch.Tensor:
        # View + expand + reshape bring num_kv_heads to num_heads for k and v
        # to match q.

        # [bsz, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.num_heads != self.num_kv_heads:
            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
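            # Shape trace: [bsz, n_kv, s, h_d] -> unsqueeze -> [bsz, n_kv, 1, s, h_d]
            # -> expand -> [bsz, n_kv, q_per_kv, s, h_d] -> flatten -> [bsz, n_h, s, h_d]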
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Reshape the output to be the same shape as the input
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference friendly version of MHA that
    separates out the inference-related parts for further optimization.
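
    Example:
        An illustrative sketch (``model`` here stands for any module tree that
        contains torchtune ``MultiHeadAttention`` layers)::

            model = replace_mha_with_inference_mha(model)
            # Optionally set up ExecuTorch-friendly KV caches on the swapped-in layers.
            for m in model.modules():
                if isinstance(m, MultiHeadAttention):
                    m.setup_cache(batch_size=1, dtype=torch.float32, max_seq_len=2048)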
    """
    _replace_mha_with_inference_mha(module)
    return module