# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from dataclasses import dataclass
from typing import Callable, List, Optional, Set, Union

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet

from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket

from executorch.exir.pass_base import ExportPass
from torch._ops import OpOverloadPacket


def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
    """Report whether tensor lifetimes and storage may overlap at ``opt_level``.

    Overlap is permitted only at optimization level 2 and above.
    """
    minimum_level_for_overlap = 2
    return opt_level >= minimum_level_for_overlap


@dataclass
class CadencePassAttribute:
    """Attributes attached to a registered Cadence pass.

    ``opt_level`` is the minimum optimization level at which
    ``create_cadence_pass_filter`` selects the pass; ``None`` means the filter
    never selects it. ``debug_pass`` marks passes that the filter selects only
    when debugging is requested.
    """

    opt_level: Optional[int] = None
    debug_pass: bool = False


# Registry of everything decorated with ``register_cadence_pass``, keyed by the
# decorated object (presumably the pass class) and mapping to its attributes.
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}


def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
    """Look up the attributes registered for pass ``p``.

    Raises:
        KeyError: if ``p`` was never registered via ``register_cadence_pass``.
    """
    return ALL_CADENCE_PASSES[p]


def register_cadence_pass(
    pass_attribute: CadencePassAttribute,
) -> Callable[[ExportPass], ExportPass]:
    """Build a decorator that records ``pass_attribute`` for the decorated pass.

    The decorated object is stored in ``ALL_CADENCE_PASSES`` and returned
    unchanged.
    """

    def _record(cls: ExportPass) -> ExportPass:
        ALL_CADENCE_PASSES[cls] = pass_attribute
        return cls

    return _record


def get_all_available_cadence_passes() -> Set[ExportPass]:
    """Return a new set containing every registered pass."""
    return set(ALL_CADENCE_PASSES)


def create_cadence_pass_filter(
    opt_level: int, debug: bool = False
) -> Callable[[ExportPass], bool]:
    """Build a predicate that selects the registered passes to run.

    A pass is selected when its registered ``opt_level`` is not ``None`` and
    does not exceed ``opt_level``, and it is either a non-debug pass or
    ``debug`` is set. The predicate raises ``KeyError`` for unregistered passes.
    """

    def _filter(p: ExportPass) -> bool:
        attribute = get_cadence_pass_attribute(p)
        required_level = attribute.opt_level
        if required_level is None or required_level > opt_level:
            return False
        # Debug-only passes are kept only when debugging was requested.
        return not attribute.debug_pass or debug

    return _filter


# Return the overload packet for the edge or torch op.
def get_overload_packet(
    op: Union[Callable[..., str], str],
) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]:
    """Return the overload packet of an edge or torch op.

    An ``EdgeOpOverload`` is resolved with ``get_edge_overload_packet``.
    Anything else yields its ``overloadpacket`` attribute, or ``None`` when it
    has no such attribute (e.g. a plain Python callable or a string target).
    """
    if isinstance(op, EdgeOpOverload):
        return get_edge_overload_packet(op)
    return getattr(op, "overloadpacket", None)


def get_node_names_list_from_gm(
    graph_module: torch.fx.GraphModule,
) -> list[str]:
    """Return, in graph order, the names of ``call_function`` nodes with edge-op targets.

    Only nodes whose target is an ``EdgeOpOverload`` are included, so this
    should be used only after ``to_edge`` has been called.
    """
    # The result holds ``node.name`` strings, not the nodes themselves.
    return [
        node.name
        for node in graph_module.graph.nodes
        if node.op == "call_function" and isinstance(node.target, EdgeOpOverload)
    ]


def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
    """Count the ``call_function`` nodes with target ``target`` in the graph."""
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and node.target == target
    )


# Testing utils
def get_compute_nodes_in_gm(
    graph_module: torch.fx.GraphModule,
) -> List[Union[OpOverloadPacket, EdgeOpOverloadPacket]]:
    """Return, in graph order, the overload packets of the compute ops in the graph.

    Only ``call_function`` nodes whose target is a torch ``OpOverload`` or an
    ``EdgeOpOverload`` contribute; any other call target is skipped. The result
    holds overload packets, not ``torch.fx.Node`` objects.
    """
    packets: List[Union[OpOverloadPacket, EdgeOpOverloadPacket]] = []
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        if isinstance(node.target, torch._ops.OpOverload):
            packets.append(node.target.overloadpacket)
        elif isinstance(node.target, EdgeOpOverload):
            packets.append(get_edge_overload_packet(node.target))
    return packets


# Return true if there is no edge from a node with target pred_target to a
# node with target succ_target in the graph.
def nodes_not_connected_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.node.Target,
    succ_target: torch.fx.node.Target,
) -> bool:
    """Return True if no node with target ``pred_target`` feeds a node with target ``succ_target``.

    Only direct edges are checked: for every node whose target equals
    ``pred_target``, its immediate users (``node.users``) are inspected.
    Transitive paths through intermediate nodes are not considered.
    """
    return not any(
        user.target == succ_target
        for node in graph_module.graph.nodes
        if node.target == pred_target
        for user in node.users
    )


def nodes_not_adjacent_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.node.Target,
    succ_target: torch.fx.node.Target,
) -> bool:
    """Return True if no ``succ_target`` node directly follows a ``pred_target`` node.

    Adjacency is positional: for every node whose target equals
    ``pred_target``, the node immediately after it in graph order
    (``node.next``) is checked, regardless of data dependencies.
    """
    return not any(
        node.next.target == succ_target
        for node in graph_module.graph.nodes
        if node.target == pred_target
    )