xref: /aosp_15_r20/external/pytorch/torch/fx/passes/dialect/common/cse_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, Tuple, Any
3
4import torch
5from torch.fx.passes.infra.pass_base import PassBase, PassResult
6from torch.utils._pytree import tree_flatten
7
8from torch.fx import GraphModule, Graph
9from torch.fx import Node
10
aten = torch.ops.aten


# Stateful / random ops must never be deduplicated: two syntactically
# identical calls can legitimately produce different results, so CSE
# would change program semantics.
rand_ops = {
    aten.dropout, aten._fused_dropout, aten._standard_gamma,
    aten.bernoulli, aten.multinomial, aten.native_dropout,
    aten.normal, aten.poisson, aten.binomial, aten.rrelu,
    aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm,
}  # noqa: E501,B950

# In-place ops mutate their inputs; merging two of them is unsound.
inplace_ops = {
    aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_,
    aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_,
}  # noqa: E501


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
    """Return the default set of ops the CSE pass must not eliminate."""
    return rand_ops | inplace_ops
23
24
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):

    def __init__(self, banned_ops=None):
        """
        A dialect-agnostic common-subexpression-elimination pass, implemented
        purely in terms of the connectivity between fx.Node objects.

        For functional dialects, the user only needs to list the random ops
        in the ban list.

        Warning: this pass cannot be safely applied to an FX graph in a
        non-functional dialect. If your dialect contains stateful operators,
        please customize ``banned_ops`` accordingly.
        """
        # Default to banning nothing; callers supply their dialect's ban list.
        self.banned_ops = set() if banned_ops is None else banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph.

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b+c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        # Map from node in the old graph to its copy in the new graph.
        old_to_new: Dict[Node, Node] = {}
        # Map from (target, hash of flattened args/kwargs) to a node in the new graph.
        hash_to_node: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}
        # Map from the same key to the full token used to rule out hash collisions.
        hash_to_token: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}

        def packet_of(node):
            # Collapse an OpOverload target to its overload packet so that
            # banned_ops entries (which are packets) match every overload.
            target = node.target
            return target.overloadpacket if hasattr(target, 'overloadpacket') else target

        def remap(tree):
            # Flatten a (possibly nested) args/kwargs structure, substituting
            # each Node already copied into the new graph for its copy.
            # The spec allows the nested structure to be reconstructed later.
            flat, spec = tree_flatten(tree)
            substituted = tuple(
                old_to_new.get(leaf, leaf) if isinstance(leaf, Node) else leaf
                for leaf in flat
            )
            return substituted, spec

        changed = False
        out_graph = Graph()
        for node in graph_module.graph.nodes:
            # Placeholder, output, and get_attr nodes are copied verbatim,
            # as are banned (random / stateful) ops, which must not be CSE'd.
            if (node.op in ('placeholder', 'output', 'get_attr')
                    or packet_of(node) in self.banned_ops):
                old_to_new[node] = out_graph.node_copy(node, lambda x: old_to_new[x])
                continue

            # node.op == 'call_function' here; 'call_module' / 'call_method'
            # should never appear in this dialect.
            args, args_spec = remap(node.args)
            kwargs, kwargs_spec = remap(node.kwargs)

            # Each token uniquely identifies a computation; nodes with equal
            # tokens are interchangeable.
            token = {"target": node.target, "args": args, "args_spec": args_spec,
                     "kwargs": kwargs, "kwargs_spec": kwargs_spec}

            # Specs are not hashable, so the key hashes only the flattened
            # values; the token comparison below guards against collisions.
            key = (node.target, hash((args, kwargs)))

            seen = key in hash_to_node
            if seen and hash_to_token[key] == token:
                # Exact duplicate: reuse the earlier copy and drop this node.
                old_to_new[node] = hash_to_node[key]
                changed = True
                continue

            copied = out_graph.node_copy(node, lambda x: old_to_new[x])
            old_to_new[node] = copied
            if not seen:
                hash_to_node[key] = copied
                hash_to_token[key] = token

        csed_gm = GraphModule(graph_module, out_graph)
        return PassResult(csed_gm, changed)
114