xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/const_fold.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import re
from typing import Callable, Dict, Optional, Set, Union

import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module


__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']

class FoldedGraphModule(torch.fx.GraphModule):
    """
    FoldedGraphModule is a GraphModule that also carries a separate
    `const_subgraph_module`: a subgraph whose inputs are all constant attrs and
    which can therefore be run a single time, before the main `graph` is run.
    The folded results are stored on the module under the attr name
    `fx_const_folded_attrs_name`, where the main `graph` picks them up via
    get_attr.
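
    Folding runs lazily the first time the module is called, or it can be
    triggered up front via `run_folding()`. A minimal sketch (`fgm` here stands
    for a FoldedGraphModule returned by `split_const_subgraphs`, and
    `some_input` is a hypothetical input tensor):

        fgm.run_folding()        # runs const_subgraph_module once, stores the results
        out = fgm(some_input)    # main graph reads the folded attrs via get_attr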
    """

    def __init__(
        self,
        root: torch.nn.Module,
        graph: torch.fx.Graph,
        const_subgraph: Optional[torch.fx.Graph] = None,
        fx_const_folded_attrs_name: Optional[str] = None,
        device_for_folded_attrs: str = "cuda",
    ):
        super().__init__(root, graph)
        self.const_subgraph_module = (
            None
            if const_subgraph is None
            else torch.fx.GraphModule(root, const_subgraph)
        )
        self.has_folding_been_run = False
        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
        self.device_for_folded_attrs = device_for_folded_attrs

    def __call__(self, *args, **kwargs):
        if not self.has_folding_been_run:
            self.run_folding()
        return super().__call__(*args, **kwargs)

    def run_folding(self):
        # If there's no const subgraph module or attr output names to use, return
        # early as there is no const folding to perform.
        if (
            self.const_subgraph_module is None
            or self.fx_const_folded_attrs_name is None
        ):
            return

        assert not self.has_folding_been_run
        self.has_folding_been_run = True

        # Actually run const folding subgraph. Note that single attr const fold
        # subgraphs output a single Tensor while multiple outputs are returned as
        # Tuple[Tensor,].
        folded_attrs = self.const_subgraph_module()

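        # Wrap each folded value as an nn.Parameter so it can be registered on the
        # module: tensors are detached and cloned (keeping their requires_grad flag),
        # while plain ints are boxed into a single-element tensor placed on
        # `device_for_folded_attrs`.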
        def _create_param(i):
            return torch.nn.Parameter(
                i.detach().clone()
                if not isinstance(i, int)
                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
            )

        params = (
            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
            if isinstance(folded_attrs, tuple)
            else _create_param(folded_attrs)
        )
        setattr(self, self.fx_const_folded_attrs_name, params)


def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
    """
    Given `gm` and the target name `inline_mod_name` of a graph module that `gm`
    calls, this helper inlines all of the nodes from that called graph module into `gm`.
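
    A minimal sketch of how this is used below (the submodule name is whatever
    `split_module` generated, e.g. "submod_1"):

        # Copies the nodes of split.submod_1's graph directly into split.graph;
        # the original call_module node is then cleaned up by eliminate_dead_code().
        _inline_module(split, "submod_1")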
    """
    # Fetch the inner graph module that we want to inline inside `gm`.
    inline_mod = dict(gm.named_modules())[inline_mod_name]
    assert isinstance(inline_mod, torch.fx.GraphModule)
    call_mod_node_to_replace = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == inline_mod_name:
            call_mod_node_to_replace = node
            break
    assert call_mod_node_to_replace is not None

    # Now actually do the swap. Note that we have to keep track of new nodes that are
    # copied into `gm` -- we do this via replacement_mapping.
    call_mod_args = call_mod_node_to_replace.args
    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    ph_count = 0

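    # Given a node from the inlined module's graph, return its counterpart in `gm`
    # (a copied node, or the matching call-site arg for placeholders), copying the
    # original node's meta onto it.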
    def replacement_fn(node):
        new_node = replacement_mapping[node]
        new_node.meta = node.meta.copy()
        return new_node

    for inline_node in inline_mod.graph.nodes:
        if inline_node.op == "placeholder":
            replacement_mapping[inline_node] = call_mod_args[ph_count]
            ph_count += 1
            continue

        if inline_node.op == "output":
            outputs = inline_node.args[0]
            output_replacements = map_arg(outputs, replacement_fn)
            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
            continue

        with gm.graph.inserting_before(call_mod_node_to_replace):
            new_node = gm.graph.node_copy(inline_node, replacement_fn)
        replacement_mapping[inline_node] = new_node

    gm.graph.eliminate_dead_code()


def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
    """
    Make sure the name is unique (in a module) and can represent an attr.
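
    For example (illustrative only): "foo.bar/baz" is sanitized to "foo_bar_baz";
    if that name is already taken on `mod_traced`, it becomes "foo_bar_baz_1",
    then "foo_bar_baz_2", and so on.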
    """
    # Replace runs of characters that are illegal in a Python identifier with "_".
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = f"_{name}"
    # Now make sure it is in fact unique to the module by incrementing suffix value.
    while hasattr(mod_traced, name):
        match = re.match(r"(.*)_(\d+)$", name)
        if match is None:
            name = name + "_1"
        else:
            base, num = match.group(1, 2)
            name = f"{base}_{int(num) + 1}"

    return name


def split_const_subgraphs(
    module: Union[torch.nn.Module, torch.fx.GraphModule],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all-constant attribute inputs
    and separates them out into their own constant subgraph. Returns a
    FoldedGraphModule that runs the constant subgraph on the first call, setting
    attributes on the module before running the non-constant portion of the
    graph.
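
    A minimal usage sketch (`MyModule` is hypothetical; its forward is assumed to
    mix constant-attr math with real inputs):

        folded = split_const_subgraphs(MyModule())
        out = folded(torch.randn(4))   # const subgraph runs once, then the main graph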
    """
    if not isinstance(module, torch.fx.GraphModule):
        mod_traced = torch.fx.symbolic_trace(module)
    else:
        mod_traced = module

    # Build up a set of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded and
        # we don't want to add tags to them.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node itself is constant, or all of its inputs are constant,
        # then tag it as constant.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # If the provided skip-folding function says to skip this node, then skip it.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Skip folding side-effectful functions.
        if node.is_impure():
            continue

        # Must be a constant foldable node at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
            found_const_folding = True

    # If we did not find any const folding then return early without a const fold subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)

    # Partition the module into two: submod_0 for the constant folding subgraph, and
    # submod_1 for the rest.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

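    # split_module groups nodes by the partition id returned above: partition 0 (the
    # const nodes) becomes the submodule "submod_0" on the returned GraphModule and
    # partition 1 becomes "submod_1".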
    split = split_module(mod_traced, module, mod_partition)

    const_mod_name, non_const_mod_name = "submod_0", "submod_1"
    # Safely get submod_1 in case there are no non-const nodes.
    const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)

    # The module that a call_module node refers to gets copied to submodules during split.
    # The path to the module also gets flattened, i.e. mod.a.b -> mod_a_b. Here we need to
    # attach the flattened modules to `split` as it's the owning module now.
    for node in non_const_gm.graph.nodes if non_const_gm else []:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))

    # split_module currently does not use get_attrs for attrs. Instead it passes
    # them in as args from the parent module, which used get_attrs. Here we set
    # them as get_attrs inside const_gm, which allows folding to run without
    # having to know a priori which attrs should be passed in as args. We can
    # unconditionally do this for all placeholders because we know all
    # placeholders to const_gm must be constants accessible via get_attr.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            call_const_gm_args = node.args
            break
    assert call_const_gm_args is not None

    # Here we do the actual replacement of placeholders with get_attrs. Note that we
    # set const_gm.graph into a new root_const_gm with `split` as the root module,
    # because we are fetching attributes directly from the root module rather than
    # from const_gm. For example, const_gm might look like:
    # graph():
    #    %inp : [num_users=1] = placeholder[target=const_inp]
    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
    #    return add
    # We replace that with the following, which does not have any placeholders:
    # graph():
    #    %inp_1 : [num_users=1] = get_attr[target=const_inp]
    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
    #    return add
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_const_gm_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
        new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        root_const_gm.graph.erase_node(node)
    assert "multiple_outputs" in locals()

    # Now find the call to const_gm inside split, and replace it with a getattr to the
    # folded tensor(s) that result from constant folding. Note that we don't need to
    # worry about whether this is one or more tensors because the original graph
    # correctly uses getitem to extract individual tensors if there are multiple folded.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        mod_traced, "_FX_CONST_FOLDED_ATTRS"
    )
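    # Register a placeholder attr of the right container type on `split`;
    # FoldedGraphModule.run_folding later overwrites it with the actual folded values.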
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
    )
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
            folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break

    # Finally, inline the non-constant submod (if it exists) into `split`.
    # This is so that the original caller who may have passed in a graph module will
    # get back out a graph module whose graph is traced to the same granularity.
    if hasattr(split, non_const_mod_name):
        _inline_module(split, non_const_mod_name)

    split.graph.eliminate_dead_code()

    return FoldedGraphModule(
        split,
        split.graph,
        root_const_gm.graph,
        fx_const_folded_attrs_name,
        device_for_folded_attrs,
    )