# mypy: ignore-errors


from typing import Callable

import torch
import torch.fx as fx
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten


aten = torch.ops.aten


def get_aten_target(node: fx.Node) -> Callable:
    if hasattr(node.target, "overloadpacket"):
        return node.target.overloadpacket
    return node.target


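# A minimal usage sketch (illustrative only; make_fx and the traced lambda are
# assumptions, not part of this module). For a node whose target is an
# OpOverload such as aten.add.Tensor, get_aten_target returns the
# OpOverloadPacket aten.add, so checks like `get_aten_target(n) in rand_ops`
# match every overload of an op:
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     gm = make_fx(lambda x: x + 1)(torch.randn(3))
#     targets = [
#         get_aten_target(n) for n in gm.graph.nodes if n.op == "call_function"
#     ]
#     assert aten.add in targets  # the packet, not aten.add.Tensor

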
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]


# Return a new torch.fx.graph.Graph with CSE applied to the input graph.
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token

    from torch._inductor.pattern_matcher import (
        compute_mutation_region_ids,
        same_mutation_regions,
    )

    compute_mutation_region_ids(fx_g)  # type: ignore[arg-type]

    # Make a set of separate storages returned from the output, which will be preserved
    # when pruning.  This prevents us from deduplicating returned tensors which have
    # experienced identical operations, but are separate data structures in eager mode.
    output_node: fx.Node = list(fx_g.nodes)[-1]
    assert output_node.op == "output"

    def checkable_node(node: fx.Node) -> bool:
        """We can evaluate only nodes that represent tensors with defined storage."""
        if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor):
            return False

        try:
            node.meta["val"].untyped_storage()
        except NotImplementedError:
            return False

        return True

    output_storages = {
        StorageWeakRef(n.meta["val"].untyped_storage())
        for n in output_node.all_input_nodes
        if checkable_node(n)
    }
    nodes_that_alias_outputs = {
        n
        for n in fx_g.nodes
        if checkable_node(n)
        and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages
    }
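    # For example, if the graph returns two separately computed tensors that
    # went through identical operations, eager mode gave them distinct
    # storages, so CSE must keep both nodes: a caller could mutate one
    # returned tensor without affecting the other.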

    for n in fx_g.nodes:
        # Placeholder, output, and get_attr nodes are copied to the new graph unchanged.
        # Do not CSE away random operations: two calls to the same random op
        # produce different values, so they are not interchangeable.
        if (
            n.op == "placeholder"
            or n.op == "output"
            or n.op == "get_attr"
            or get_aten_target(n) in rand_ops
            # aten.empty is non-deterministic, so don't CSE it.
            # Also, aten.empty is almost always fusible into its consumer,
            # so it's not worth CSEing.
            or get_aten_target(n) is aten.empty
            or n in nodes_that_alias_outputs
        ):
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function'; we should never see 'call_module' or 'call_method'
            # Substitute args and kwargs members with their mappings in env, if they exist.
            # The specs can be used to reconstruct nested lists/dictionaries.
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                    if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                        arg_list[i] = v.node
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }

            # hash substituted args to a number, do not hash specs because specs are not hashable
            # We need to add type into hash to avoid situations like:
            # hash((primals_2, 1.0)) == hash((primals_2, 1))
            hash_arg = hash(
                (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
            )
            hash_val = (n.target, hash_arg)

            # check if a node has a substitute and can be eliminated
            hash_val_in_hash_env = hash_val in hash_env
            overwrite_due_to_mutation = False
            if hash_val_in_hash_env and token_map[hash_val] == token:
                duplicate_n_prev = hash_env[hash_val]
                if same_mutation_regions(n, duplicate_n_prev):
                    env[n] = duplicate_n_prev
                    continue
                else:
                    # any future duplicates should be substituted with n, not duplicate_n_prev
                    overwrite_due_to_mutation = True

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if overwrite_due_to_mutation or not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    return new_graph


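# A minimal usage sketch (illustrative only; make_fx and the function f are
# assumptions, not part of this module). CSE collapses the two identical
# cos calls into a single node:
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def f(x):
#         return x.cos() + x.cos()
#
#     gm = make_fx(f)(torch.randn(3))
#     gm.graph = fx_graph_cse(gm.graph)
#     gm.recompile()

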
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


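# A minimal usage sketch (illustrative only; make_fx is an assumption). After
# strip_overloads, a target like aten.add.Tensor becomes the packet aten.add,
# which helps consumers that cannot handle OpOverload targets:
#
#     gm = make_fx(lambda x: x + 1)(torch.randn(3))
#     strip_overloads(gm)  # mutates gm in place and recompiles it

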
def get_placeholders(graph):
    return graph.find_nodes(op="placeholder")


def get_outputs(graph):
    for node in graph.find_nodes(op="output"):
        return pytree.tree_leaves(node.args[0])
    raise AssertionError("No output node found")
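

# A minimal usage sketch (illustrative only; make_fx is an assumption):
#
#     gm = make_fx(lambda x, y: (x + y,))(torch.randn(3), torch.randn(3))
#     placeholders = get_placeholders(gm.graph)  # the two input nodes
#     outputs = get_outputs(gm.graph)            # flattened output value nodes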
177