# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
from executorch.backends.transforms import get_shape

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import sdpa
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class ConvertToSDPAPass(XNNPACKPass):
    """
    Replaces every subgraph that matches one of the patterns returned by
    `sdpa.get_graphs()` with a single `scaled_dot_product_attention` node.
    """

    def get_scale(self, match: InternalMatch) -> Optional[float]:
        """
        Returns the scale of the SDPA op.

        Scale: Optional[float] doesn't change the graph pattern.
        The default value can be calculated however we need to extract
        it for lowering when it is the user supplied value anyway.

        Only the first `mul.Scalar` node found in the match is inspected.

        Args:
            match: A subgraph match produced by the SubgraphMatcher.

        Returns:
            The square of the scalar operand of that `mul.Scalar` node, or
            None when the match contains no such node.
        """
        for node in match.nodes_map.values():
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.mul.Scalar
            ):
                scale = node.args[1]

                # Precision the multiplication runs at: the dtype of the mul
                # output when its "val" metadata exists, float32 otherwise.
                dtype = torch.float
                mul_val = node.meta.get("val", None)
                if mul_val is not None:
                    dtype = mul_val.dtype

                if isinstance(scale, float):
                    # Round-trip the Python float through that dtype so the
                    # value carries the same (possibly reduced, e.g. fp16)
                    # precision as the matched graph.
                    scale = torch.tensor(scale, dtype=dtype).item()

                # The scalar was extracted from a multiplication that happens
                # before QK^T, so the overall SDPA scale is its square.
                # NOTE(review): presumably the pattern applies this factor to
                # both Q and K - confirm against the sdpa pattern graphs.
                return scale**2
        return None

    def assert_2d_mask(self, match: InternalMatch) -> None:
        """
        No better way to do this right now. Ideally we don't want to partition this.

        The last placeholder of the match is treated as the attention mask.

        Raises:
            Exception: If the shape of that placeholder is not rank 2.
        """
        mask = match.placeholder_nodes[-1]
        mask_shape = get_shape(mask)
        if len(mask_shape) != 2:
            raise Exception(f"Mask rank is not 2 got {mask_shape}")

    def create_sdpa(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ):
        """
        Replaces one matched subgraph with a single SDPA node.

        The new node takes the match's placeholder nodes as positional
        arguments and the extracted scale as the `scale` keyword argument.
        All users of the match's first returning node are redirected to the
        new node, after which dead code is eliminated from the graph.

        Args:
            graph_module: The graph module that owns the matched nodes.
            match: The subgraph match to replace.

        Raises:
            AssertionError: If no scale could be extracted from the match.
            Exception: If the mask placeholder is not rank 2.
        """
        logger.debug("Matched Subgraph: %s", match)

        scale = self.get_scale(match)
        if scale is None:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError("Could not find scale")
        logger.debug("scale: %s", scale)

        self.assert_2d_mask(match)

        output = match.returning_nodes[0]

        with graph_module.graph.inserting_before(output):
            sdpa_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.scaled_dot_product_attention.default,  # HACK not edge_op/CATen
                tuple(match.placeholder_nodes),
                kwargs={"scale": scale},
            )

        # Populate the output metadata by running the op on the "val"
        # metadata of its inputs.
        sdpa_node.meta["val"] = sdpa_node.target(  # pyre-fixme[29]
            *[n.meta["val"] for n in match.placeholder_nodes],
            scale=scale,
        )

        # Shapes are only looked up when the message will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Replacing %s%s node with %s%s",
                output,
                get_shape(output),
                sdpa_node,
                get_shape(sdpa_node),
            )
        output.replace_all_uses_with(sdpa_node)
        graph_module.graph.eliminate_dead_code()

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        """
        Runs the pass: converts every SDPA pattern match into an SDPA node.

        For each pattern from `sdpa.get_graphs()`, all matches are collected
        (literals ignored) and replaced one by one. The module is then
        recompiled and handed to the parent pass's `call`.

        Args:
            graph_module: The graph module to transform.

        Returns:
            A PassResult wrapping the graph module returned by the parent
            pass, always flagged as modified.
        """
        # Rendering the graph is only worth doing when it will be logged.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        logger.debug("ConvertToSDPA Begin: ")
        if debug_enabled:
            logger.debug(graph_module.print_readable(print_output=False))

        for pattern in sdpa.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            for partition_to_replace in matches:
                self.create_sdpa(graph_module, partition_to_replace)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        logger.debug("ConvertToSDPA End: ")
        if debug_enabled:
            logger.debug(graph_module.print_readable(print_output=False))

        return PassResult(graph_module, True)