# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import operator

import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility

__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']

Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}


@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    if isinstance(k, str):
        return k
    elif k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    else:
        module = k.__module__.replace('torch._ops', 'torch.ops')  # workaround for how torch.ops assigns __module__ ("torch._ops" instead of "torch.ops")
        return f"{module if module else ''}.{k.__name__}"
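
# A minimal usage sketch (hypothetical `my_fn`; the exact string depends on
# the module in which the callable is defined):
#
#     def my_fn(x): ...            # suppose it is defined in module "mylib.ops"
#     get_acc_ops_name(my_fn)      # -> "mylib.ops.my_fn"
#     get_acc_ops_name("relu")     # strings pass through unchanged -> "relu"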


@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Given a `node`, returns its target typename.

    For a "call_method" node, returns node.target, which is the name of the method
    being called. This could potentially lead to conflicts, but should be okay
    because normally the method is called on a tensor.

    For a "call_function" node, returns the typename of node.target.

    For a "call_module" node, returns the typename of the module that node.target
    points to.

    If "_VariableFunctionsClass" appears in the target name string, it is replaced
    by "torch", e.g. _VariableFunctionsClass.relu becomes torch.relu.
    """

    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        submod = submodules[node.target]
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        return get_acc_ops_name(submod_type)
    elif node.op == "call_function":
        target: Any = node.target
        return (
            f"acc_ops.{target.__name__}"
            if target.__module__ is not None and "acc_ops" in target.__module__
            else _get_qualified_name(target)
        )
    else:
        assert isinstance(node.target, str)
        return node.target
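
# Usage sketch (hypothetical `MyModule`):
#
#     gm = torch.fx.symbolic_trace(MyModule())
#     submodules = dict(gm.named_modules())
#     for n in gm.graph.nodes:
#         if n.op in CALLABLE_NODE_OPS:
#             print(n.name, get_node_target(submodules, n))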

@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks whether the node's output is a Tensor.

    NOTE: This requires running `ShapeProp` on the containing fx graph before
    calling this function. It works by checking the `type` metadata on the
    node, which is produced by `ShapeProp`.
    """
    type_ = node.meta.get("type", None)
    return type_ is not None and issubclass(type_, torch.Tensor)
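
# Usage sketch (hypothetical module and example input; ShapeProp lives in
# torch.fx.passes.shape_prop):
#
#     from torch.fx.passes.shape_prop import ShapeProp
#     gm = torch.fx.symbolic_trace(MyModule())
#     ShapeProp(gm).propagate(torch.randn(1, 3))
#     tensor_nodes = [n for n in gm.graph.nodes if is_node_output_tensor(n)]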

@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        self.nodes = list(module.graph.nodes)
        self.acc_nodes = acc_nodes

    @dataclass
    class FusionGroup:
        # The smallest index among this fusion group's nodes after topologically
        # sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to the fusion group.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ):
        """
        Starting from `inputs`, walk the graph in reverse topological order. If any
        upstream node is in the fusion group, add all the nodes on this path to the
        fusion group.
        """
        for arg in inputs:
            # Skip the node if it has already been seen.
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has a smaller index, it's already an upstream node of the
            # fusion group. We don't need to check it anymore.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of this node; if any of them is in the fusion
            # group, add this node to the fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False
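
    # Illustration (hypothetical chain a -> b -> c): if the fusion group is
    # {a, c} and `inputs` contains b, the walk from b reaches a, which is
    # already in the group, so b is absorbed as well, closing the a -> b -> c path.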

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        result: Dict[torch.fx.Node, NodeSet] = {}
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            # Seed fusion groups only from ACC nodes whose output is not a tensor
            # (no "tensor_meta") and that are still in the ACC set.
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Add downstream nodes: if this node's output is not a tensor, its
                # users must be fused into the same group.
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add upstream nodes that feed this node non-tensor data.
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            # If the fusion group grew beyond the ACC node set, it can't be fused;
            # remove its nodes from the ACC set instead.
            if not (fusion_group.nodes <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
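
# Usage sketch (hypothetical setup; in a real splitter, `acc_nodes` would come
# from an operator-support query):
#
#     gm = torch.fx.symbolic_trace(MyModule())
#     acc_nodes = {n for n in gm.graph.nodes if n.op in CALLABLE_NODE_OPS}
#     fusion_map = FxNetAccFusionsFinder(gm, acc_nodes)()
#     # fusion_map maps each fused node to the full node set of its group
#     # (assumes ShapeProp has populated tensor_meta on tensor-output nodes)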


@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used, for example, by the merge_matmul transformation, which disturbs the
    topologically sorted order of its input GraphModule; this pass restores that order
    before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module, sorted in-place.
    """

    # These operators are used for making runtime assertions before any
    # data-dependent operators occur. We want to prioritize sorting these to
    # ensure that these assertions appear before any data-dependent operations
    # in the graph.
    PRIORITIZED_OPS = [
        operator.add,
        operator.mul,
        operator.sub,
        operator.floordiv,
        operator.truediv,
        operator.mod,
        operator.le,
        operator.lt,
        operator.ge,
        operator.gt,
        operator.eq,
        operator.ne,
        torch.ops.aten.sym_constrain_range.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten._assert_async.msg,
        torch.ops.aten.scalar_tensor.default,
        torch.ops.aten._assert_scalar.default,
    ]

    indeg = dict.fromkeys(gm.graph.nodes, 0)
    new_graph = torch.fx.Graph()
    # Track how many unfulfilled dependencies each node has
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1
    queue: collections.deque = collections.deque()
    # Add all nodes with no dependencies to the queue
    for node in gm.graph.nodes:
        if indeg[node] == 0:
            queue.append(node)
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Pop nodes from the queue, and add nodes that have had all their
    # dependencies fulfilled
    while len(queue) > 0:
        cur = queue.popleft()
        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                if user.op == "call_function" and user.target in PRIORITIZED_OPS:
                    queue.appendleft(user)
                else:
                    queue.append(user)
    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied).
    if len(new_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph
    return gm
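
# Usage sketch: after a transform that leaves the node list out of topological
# order (hypothetical `some_reordering_pass`), restore a valid order in place
# and regenerate the module's code:
#
#     gm = some_reordering_pass(gm)
#     legalize_graph(gm)
#     gm.recompile()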