xref: /aosp_15_r20/external/executorch/backends/xnnpack/_passes/convert_to_sdpa.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8from typing import Optional
9
10import torch
11from executorch.backends.transforms import get_shape
12
13from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
14from executorch.backends.xnnpack.partition.graphs import sdpa
15from executorch.exir.dialects._ops import ops as exir_ops
16
17from torch.fx.passes.infra.pass_base import PassResult
18from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
19
20logger = logging.getLogger(__name__)
21logger.setLevel(logging.WARNING)
22
23
24class ConvertToSDPAPass(XNNPACKPass):
25    def get_scale(self, match: InternalMatch) -> Optional[float]:
26        """
27        Returns the scale of the SDPA op.
28
29        Scale: Optional[float] doesn't change the graph pattern.
30        The default value can be calulated however we need to extract
31        it for lowering when it is the user supplied value anyway.
32        """
33        for node in match.nodes_map.values():
34            if (
35                node.op == "call_function"
36                and node.target == exir_ops.edge.aten.mul.Scalar
37            ):
38                scale = node.args[1]
39
40                dtype = torch.float
41                mul_val = node.meta.get("val", None)
42                if mul_val is not None:
43                    dtype = mul_val.dtype
44
45                if isinstance(scale, float):
46                    # Convert scale value to fp16 (reducing precision)
47                    scale = torch.tensor(scale, dtype=dtype).item()
48
49                    # since scale we extracted this before the QK^T.
50                    return scale**2
51                break
52        return None
53
54    def assert_2d_mask(self, match: InternalMatch) -> None:
55        """
56        No better way to do this right now. Ideally we don't want to partition this.
57        """
58        mask = match.placeholder_nodes[-1]
59        mask_shape = get_shape(mask)
60        if len(mask_shape) != 2:
61            raise Exception(f"Mask rank is not 2 got {mask_shape}")
62
63    def create_sdpa(
64        self,
65        graph_module: torch.fx.GraphModule,
66        match: InternalMatch,
67    ):
68        logger.debug(f"Matched Subgraph: {match}")
69
70        scale = self.get_scale(match)
71        assert scale is not None, "Could not find scale"
72        logger.debug(f"scale: {scale}")
73
74        self.assert_2d_mask(match)
75
76        output = match.returning_nodes[0]
77
78        with graph_module.graph.inserting_before(output):
79            sdpa_node = graph_module.graph.create_node(
80                "call_function",
81                exir_ops.edge.aten.scaled_dot_product_attention.default,  # HACK not edge_op/CATen
82                tuple(match.placeholder_nodes),
83                kwargs={"scale": scale},
84            )
85
86        sdpa_node.meta["val"] = sdpa_node.target(  # pyre-fixme[29]
87            *[n.meta["val"] for n in match.placeholder_nodes],
88            scale=scale,
89        )
90
91        logger.debug(
92            f"Replacing {output}{get_shape(output)} node with {sdpa_node}{get_shape(sdpa_node)}"
93        )
94        output.replace_all_uses_with(sdpa_node)
95        graph_module.graph.eliminate_dead_code()
96
97    # override
98    def call(self, graph_module: torch.fx.GraphModule):
99        logger.debug("ConvertToSDPA Begin: ")
100        logger.debug(graph_module.print_readable(print_output=False))
101
102        for pattern in sdpa.get_graphs():
103            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
104            matches = list(sm.match(graph_module.graph))
105            for partition_to_replace in matches:
106                self.create_sdpa(graph_module, partition_to_replace)
107
108        graph_module.recompile()
109        graph_module = super().call(graph_module).graph_module
110
111        logger.debug("ConvertToSDPA End: ")
112        logger.debug(graph_module.print_readable(print_output=False))
113
114        return PassResult(graph_module, True)
115