# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache
from typing import List, Optional

import torch
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from torch import Tensor
from torch.export import export


@lru_cache(maxsize=None)
def get_graphs() -> List[torch.fx.GraphModule]:
    """
    Returns a list of SDPA pattern graphs, exported to the edge dialect and
    cached across calls, for matching scaled_dot_product_attention subgraphs
    during XNNPACK partitioning.
    """
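
    # The pattern is written as a small nn.Module so it can be traced with
    # torch.export and lowered to the edge dialect below.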
    class SDPA(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout_p: float = 0.0
            self.is_causal: bool = False
            self.scale: Optional[float] = None

        def forward(
            self,
            q: Tensor,
            k: Tensor,
            v: Tensor,
            attn_mask: Optional[Tensor] = None,
        ):
            return torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.dropout_p,
                is_causal=self.is_causal,
                scale=self.scale,
            )
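
    # dropout_p=0.0, is_causal=False, and scale=None are baked into the
    # exported pattern, so it is expected to match SDPA calls that use
    # these settings.

    # Example inputs used to trace the pattern below; the sizes chosen here
    # are illustrative.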
    batch_size = 8
    heads = 16
    seq_len = 32
    dim = 64

    q = torch.randn(batch_size, heads, seq_len, dim)
    k = torch.randn(batch_size, heads, seq_len, dim)
    v = torch.randn(batch_size, heads, seq_len, dim)

    # TODO: add support for
    # 1. None - the mask should be inserted later on
    # 2. >2d tensors - requires general unsqueeze support from newer XNNPACK
    masks = [torch.full((seq_len, seq_len), 0, dtype=torch.float)]
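
    # Trace one pattern per (mask, dtype) combination and collect the edge
    # dialect graph modules to serve as match targets.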
    graphs = []
    for mask in masks:
        # float32 and float16 seem to generate different graphs - P1136301928
        for dtype in [torch.float, torch.float16]:
            q = q.to(dtype)
            k = k.to(dtype)
            v = v.to(dtype)
            mask = mask.to(dtype)
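
            # Export the eager module and convert it to the edge dialect with
            # the XNNPACK compile config, so the pattern is built the same way
            # as the graphs being partitioned.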
            edge = to_edge(
                export(
                    SDPA(),  # pyre-ignore[16]
                    (
                        q,
                        k,
                        v,
                        mask,
                    ),
                ),
                compile_config=get_xnnpack_edge_compile_config(),
            )
            gm = edge.exported_program().graph_module
            graphs.append(gm)

    return graphs
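

# Minimal usage sketch (illustrative, not part of the module's API): build the
# cached pattern graphs and print a short summary of each.
if __name__ == "__main__":
    for i, gm in enumerate(get_graphs()):
        print(f"pattern {i}: {len(gm.graph.nodes)} nodes")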