# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """
    NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except
    that SDPA is factored out so that it can be swapped for more
    efficient ExecuTorch-defined SDPA ops.

    Multi-headed attention layer with support for grouped query
    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multi-headed attention (MHA) which uses fewer
    key/value heads than query heads by grouping n query heads for each
    key and value head. Multi-Query Attention (MQA) is an extreme
    version where we have a single key and value head shared by all
    query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4

    (credit for the documentation:
    `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).


    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
        n_kv_heads=4           n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
            before updating from kv_cache. This means it will only support token-wide normalization and not
            batch- or sequence-wide normalization.
        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        is_causal (bool): sets the default mask to causal when no mask is provided.
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
            Default value is 0.0.

    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
        ValueError: If q_norm is defined without k_norm or vice versa
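
    Example:
        A minimal construction sketch. The plain ``nn.Linear`` projections, the
        sizes, and the random input below are illustrative assumptions, not the
        layers used by any particular model::

            >>> import torch
            >>> from torch import nn
            >>> embed_dim, num_heads, num_kv_heads = 256, 8, 2
            >>> head_dim = embed_dim // num_heads
            >>> attn = MultiHeadAttention(
            ...     embed_dim=embed_dim,
            ...     num_heads=num_heads,
            ...     num_kv_heads=num_kv_heads,
            ...     head_dim=head_dim,
            ...     q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            ...     k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            ...     v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            ...     output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
            ... )
            >>> x = torch.randn(2, 16, embed_dim)  # [b, s_x, d]
            >>> attn(x, x).shape  # self-attention: y == x
            torch.Size([2, 16, 256])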
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: Optional[nn.Module] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")

        if bool(q_norm) ^ bool(k_norm):
            raise ValueError("q and k norm must be set together")

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.pos_embeddings = pos_embeddings

        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()
        self._sdpa = SDPA(
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            attn_dropout=self.attn_dropout if self.training else 0.0,
            is_causal=self.is_causal,
            attention_fn=self._attention_call,
            kv_cache=self.kv_cache,
        )

        # This flag indicates whether to update the kv-cache during forward
        # passes. When disabled, we can have the cache set up but still
        # perform normal forward passes.
        self.cache_enabled = False

    def setup_cache(
        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
    ) -> None:
        """Set up key value caches for attention calculation. If called
        after kv_cache is already set up, this will be skipped.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            max_seq_len (int): maximum sequence length the model will be run with.
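
        Example:
            A usage sketch; the ``attn`` module comes from the class docstring
            example and the sizes here are illustrative assumptions::

                >>> attn.setup_cache(batch_size=1, dtype=torch.float32, max_seq_len=128)
                >>> attn.cache_enabled
                True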
        """
        # Don't overwrite user-defined kv_cache from init
        if self.kv_cache is not None:
            logger.warning(
                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
            )
        else:
            self.kv_cache = InferenceKVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                transpose_cache=False,
            )
            self._sdpa.kv_cache = self.kv_cache
            self.cache_enabled = True

    def reset_cache(self):
        """Reset the key value caches."""
        if self.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        self.kv_cache.reset()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
                for k and v. For self-attention, x=y. Optional only with kv_cache enabled.
            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
                and before the softmax. Either:

                A boolean tensor with shape ``[b x s x s]``, or ``[b x s x self.encoder_max_cache_seq_len]``
                if using KV-caching with encoder/decoder layers.
                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
                token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
                is used by default.

                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
                created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
                Default is None.
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Raises:
            AssertionError: If ``y`` is not provided and ``kv_cache`` is not enabled.

        Returns:
            torch.Tensor: output tensor with attention applied

        Notation used for tensor shapes:
            - b: batch size
            - s_x: sequence length for x
            - s_y: sequence length for y
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim
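
        Example:
            A self-attention call without a KV cache, continuing the class
            docstring example; the ``attn`` module, the shapes, and the boolean
            mask below are illustrative assumptions::

                >>> b, s = 2, 16
                >>> x = torch.randn(b, s, attn.embed_dim)
                >>> causal = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(b, s, s)
                >>> attn(x, x, mask=causal).shape
                torch.Size([2, 16, 256])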
        """
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads
        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

        # Apply positional embeddings
        if self.pos_embeddings is not None:
            q = self.pos_embeddings(q, input_pos=input_pos)

        # Normalize q
        if self.q_norm is not None:
            q = self.q_norm(q)

        def calculate_kv(y):
            # Update k and v shape, positional embeddings, and normalization
            s_y = y.shape[1]
            # k has shape [b, s_y, num_kv_heads * head_dim]
            # v has shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
            v = self.v_proj(y)

            # Apply positional embeddings
            # k: [b, s_y, n_kv, h_d]
            k = k.view(b, s_y, -1, self.head_dim)
            v = v.view(b, s_y, -1, self.head_dim)
            if self.pos_embeddings is not None:
                k = self.pos_embeddings(k, input_pos=input_pos)

            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)
            return k, v

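        # With a KV cache present, an all-NaN ``y`` is treated as a sentinel for
        # "no new k/v input": true_fn simply reuses the cached k/v, while false_fn
        # computes fresh k/v and folds them into a clone of the cache. Keeping both
        # branches as functions lets torch.cond trace them for export.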
        def true_fn(y):
            kv_cache = self.kv_cache.clone()
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        def false_fn(y):
            k, v = calculate_kv(y)
            kv_cache = self.kv_cache.clone()
            kv_cache.update(k, v)
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        # If kv cache is None, we expect y to be provided
        if self.kv_cache is None:
            assert (
                y is not None
            ), "Must provide y input or use kv_cache to enable streaming decoding"
            k, v = calculate_kv(y)
        else:
            # Expect the k and v returned here to be the same size as self.kv_cache.
            # In eager, we expect this predicate to specialize. In export, this will
            # become a SymBool so it's not specialized.
            k, v, cache_pos = torch.cond(
                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
            )
            # Update key-value cache
            self.kv_cache.k_cache.copy_(k)
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

        output = self._sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)


class SDPA(nn.Module):
    """
    TorchTune's SDPA, which can be optimized and swapped out for a more
    efficient implementation.
    """

    def __init__(
        self,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
        attn_dropout: float,
        is_causal: bool,
        attention_fn,
        kv_cache,
    ) -> None:
        super().__init__()
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_per_kv = self.num_heads // self.num_kv_heads
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache

    def forward(
        self,
        q: torch.Tensor,  # [b, s, n_h, h_d]
        k: torch.Tensor,  # [b, s, n_kv, h_d]
        v: torch.Tensor,  # [b, s, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[_MaskType] = None,
    ) -> torch.Tensor:
        # View + expand + reshape bring num_kv_heads to num_heads for k and v
        # to match q.

        # [bsz, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
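        # e.g. with num_heads=8 and num_kv_heads=2 (q_per_kv=4):
        # [b, 2, s, h_d] -> unsqueeze -> [b, 2, 1, s, h_d] -> expand ->
        # [b, 2, 4, s, h_d] -> flatten(1, 2) -> [b, 8, s, h_d]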
        if self.num_heads != self.num_kv_heads:
            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Reshape the output to be the same shape as the input
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference-friendly version of MHA that
    separates out the inference-related parts for further optimization.
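
    Example:
        A usage sketch; ``model`` stands in for any module that contains
        torchtune ``MultiHeadAttention`` layers::

            >>> model = replace_mha_with_inference_mha(model)
            >>> # attention layers are now the inference-friendly MultiHeadAttention above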
    """
    _replace_mha_with_inference_mha(module)
    return module