xref: /aosp_15_r20/external/executorch/backends/cadence/aot/pass_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3# pyre-strict
4
5from dataclasses import dataclass
6from typing import Callable, List, Optional, Set, Union
7
8import torch
9from executorch.backends.cadence.aot.utils import get_edge_overload_packet
10
11from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
12
13from executorch.exir.pass_base import ExportPass
14from torch._ops import OpOverloadPacket
15
16
def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
    """Return whether tensor lifetimes and storage may overlap at `opt_level`.

    Overlap is permitted only once the optimization level reaches 2.
    """
    overlap_min_level = 2
    return opt_level >= overlap_min_level
21
22
@dataclass
class CadencePassAttribute:
    """Attributes recorded for an ExportPass when it is registered.

    Instances are stored in ``ALL_CADENCE_PASSES`` by ``register_cadence_pass``
    and consulted by ``create_cadence_pass_filter``.
    """

    # Lowest requested opt level at which the pass is selected by the filter;
    # None means the filter never selects the pass.
    opt_level: Optional[int] = None
    # When True, the pass is selected only if the filter was created with debug=True.
    debug_pass: bool = False
28
29
# Registry filled in by `register_cadence_pass`: each registered pass maps to the
# CadencePassAttribute it was registered with.
# NOTE(review): the decorator stores the object it decorates (its parameter is
# named `cls`), so keys are presumably pass classes rather than instances; confirm
# whether `type[ExportPass]` is the intended key annotation.
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = dict()
32
33
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
    """Return the CadencePassAttribute registered for `p`.

    Raises:
        KeyError: if `p` is not a key of ALL_CADENCE_PASSES, i.e. it was never
            registered via `register_cadence_pass`.
    """
    registered_attribute = ALL_CADENCE_PASSES[p]
    return registered_attribute
36
37
def register_cadence_pass(
    pass_attribute: CadencePassAttribute,
) -> Callable[[ExportPass], ExportPass]:
    """Create a decorator that registers a pass together with `pass_attribute`.

    The returned decorator stores ``pass_attribute`` in ``ALL_CADENCE_PASSES``
    keyed by the decorated object, then hands that object back unchanged.
    Registering the same object again overwrites its previous attribute.
    """

    def _record(cls: ExportPass) -> ExportPass:
        ALL_CADENCE_PASSES.update({cls: pass_attribute})
        return cls

    return _record
47
48
def get_all_available_cadence_passes() -> Set[ExportPass]:
    """Return every registered pass as a new set (a snapshot of the registry keys)."""
    # Iterating a dict yields its keys, so this copies the registered passes.
    return set(ALL_CADENCE_PASSES)
51
52
def create_cadence_pass_filter(
    opt_level: int, debug: bool = False
) -> Callable[[ExportPass], bool]:
    """Build a predicate that selects the registered passes relevant to a run.

    A pass is selected when its registered ``opt_level`` is set and does not
    exceed ``opt_level``. A pass registered with ``debug_pass=True`` is selected
    only when ``debug`` is truthy. Looking up an unregistered pass raises
    ``KeyError``.
    """

    def _is_selected(p: ExportPass) -> bool:
        attribute = get_cadence_pass_attribute(p)
        if attribute.opt_level is None:
            return False
        if not attribute.opt_level <= opt_level:
            return False
        # Non-debug passes always qualify; debug passes only in debug mode.
        return not attribute.debug_pass or debug

    return _is_selected
66
67
def get_overload_packet(
    op: Union[Callable[..., str], str],
) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]:
    """Return the overload packet that `op` belongs to.

    Edge-dialect overloads (EdgeOpOverload) are resolved through
    ``get_edge_overload_packet``. For anything else the ``overloadpacket``
    attribute is returned, or None when ``op`` has no such attribute
    (for example, a plain string).
    """
    if isinstance(op, EdgeOpOverload):
        return get_edge_overload_packet(op)
    return getattr(op, "overloadpacket", None)
77
78
def get_node_names_list_from_gm(
    graph_module: torch.fx.GraphModule,
) -> list[str]:
    """Return the names of the edge-dialect call nodes in `graph_module`, in graph order.

    Only nodes with ``op == "call_function"`` whose target is an EdgeOpOverload
    are included. Placeholders, outputs and calls to non-edge targets are
    skipped. This should be used only after to_edge is called.

    Returns:
        The ``name`` of each matching node. These are strings, not
        ``torch.fx.Node`` objects.
    """
    return [
        node.name
        for node in graph_module.graph.nodes
        if node.op == "call_function" and isinstance(node.target, EdgeOpOverload)
    ]
92
93
def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
    """Return how many ``call_function`` nodes in `graph_module` have target `target`.

    Nodes of any other kind (placeholders, outputs, ...) are never counted, even
    if their target compares equal to `target`.
    """
    return sum(
        1
        for candidate in graph_module.graph.nodes
        if candidate.op == "call_function" and candidate.target == target
    )
101
102
103# Testing utils
# Return the overload packets of the compute ("call_function") nodes in the graph.
def get_compute_nodes_in_gm(
    graph_module: torch.fx.GraphModule,
) -> "List[Union[OpOverloadPacket, EdgeOpOverloadPacket]]":
    """Return the overload packet of every op-overload call node, in graph order.

    Despite the function name, the result holds overload packets (e.g. the
    packet of ``torch.ops.aten.add.Tensor``), not ``torch.fx.Node`` objects:

    * a ``torch._ops.OpOverload`` target contributes its ``overloadpacket``;
    * an ``EdgeOpOverload`` target contributes ``get_edge_overload_packet(target)``;
    * ``call_function`` nodes with any other target, and all non-call nodes,
      are skipped.
    """
    packets = []
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        if isinstance(node.target, torch._ops.OpOverload):
            packets.append(node.target.overloadpacket)
        elif isinstance(node.target, EdgeOpOverload):
            packets.append(get_edge_overload_packet(node.target))
    return packets
114
115
def nodes_not_connected_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.node.Target,
    succ_target: torch.fx.node.Target,
) -> bool:
    """Return True if no `pred_target` node feeds directly into a `succ_target` node.

    Args:
        graph_module: The graph module to inspect.
        pred_target: Target (e.g. an op overload) of the producer nodes. This is
            compared against ``node.target``; it is not a ``torch.fx.Node``.
        succ_target: Target of the consumer nodes.

    Returns:
        False as soon as some node whose target equals ``pred_target`` has a user
        whose target equals ``succ_target``; True otherwise. Only direct
        producer -> user edges are checked, not transitive paths.
    """
    for node in graph_module.graph.nodes:
        if node.target == pred_target and any(
            user.target == succ_target for user in node.users
        ):
            return False
    return True
130
131
def nodes_not_adjacent_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.node.Target,
    succ_target: torch.fx.node.Target,
) -> bool:
    """Return True if no `succ_target` node immediately follows a `pred_target` node.

    Adjacency is positional: for each node whose target equals ``pred_target``,
    the node directly after it in graph order (``node.next``) is checked. Data
    dependencies are not considered.

    Args:
        graph_module: The graph module to inspect.
        pred_target: Target (e.g. an op overload) of the leading node. This is
            compared against ``node.target``; it is not a ``torch.fx.Node``.
        succ_target: Target of the node that must not come right after it.

    Returns:
        False if such an adjacent pair exists, True otherwise.
    """
    for node in graph_module.graph.nodes:
        if node.target == pred_target and node.next.target == succ_target:
            return False
    return True
145