# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code


__all__ = ["Partition", "split_module"]
log = _LOGGER = logging.getLogger(__name__)

@compatibility(is_backward_compatible=True)
class Partition:
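    """Bookkeeping record for a single partition produced by split_module: the
    names of the nodes assigned to it, its inputs/outputs, the partitions it
    depends on / that depend on it, and the subgraph being built for it."""
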
    def __init__(self, name: str):
        self.name: str = name
        self.submod_name = f"submod_{name}"
        self.node_names: List[str] = []
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        self.environment: Dict[Node, Node] = {}
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions depended on: {self.dependencies},\n"
            f" partition dependents: {self.dependents}"
        )


# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: whether to keep the original order of the
            GraphModule or to use the topological order of the newly constructed
            GraphModule
        keep_original_node_name: Optional[bool]: whether to keep the original node
            names in the split submodules and in the base graph placeholders


    Returns:
        GraphModule: the module after split.

    Example:

        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module into a module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        The output looks like this. The original graph is broken into partitions:

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1];  submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        The output of the split module is the same as the output of the input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
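
        qualname_map, when provided, is filled in as an output parameter. A minimal
        sketch reusing the setup above (the resulting mapping shown is illustrative):

            qualname_map = {}
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition, qualname_map
            )
            # qualname_map now maps new target names to the originals,
            # e.g. {"submod_0.linear": "linear", "submod_1.linear": "linear"}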
145    """
146
147    log.debug(
148        "%s",
149        lazy_format_graph_code(
150            "pre split_module", m, colored=True
151        ),
152    )
153
154    def construct_graph(
155        node: Node,
156        base_mod_env: Dict[str, Node],
157        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
158    ):
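        # Copies a placeholder or get_attr node from the original graph into the
        # base graph being built, and records any attribute value it refers to so
        # that it can be attached to the new base module.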
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = () if default_value is inspect.Signature.empty else (default_value,)
                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)  # type: ignore[arg-type]
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value  # type: ignore[arg-type]
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    import sympy

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
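    # Maps each SymPy size symbol to the node that binds it; populated in the scan
    # below and consulted when a value with symbolic sizes crosses a partition boundary.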
    symbol_to_node: Dict[sympy.Symbol, Node] = {}

    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
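        # def_node is defined in one partition and used by use_node (use_node is
        # None when the user is the overall graph output).  If they live in
        # different partitions, record def_node as an output of the defining
        # partition and an input of the using one, and add the dependency edge.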
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)

        log.debug(
            "record_cross_partition_use %s (%s) %s (%s)",
            def_node.name, defined, use_node.name if use_node is not None else "-", used
        )

        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # We have made def_node an input to the use_partition.  If
                # this input has symbolic symbols in its size, those must also
                # be made inputs to the partition.
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        s_node = symbol_to_node[s]
                        use_partition.inputs.setdefault(s_node.name)
                        if symbol_to_node[s].op != "placeholder":
                            # If the node that defines the symbol is not a
                            # placeholder, we must make it an output of the
                            # partition.  Note that this may be in a different
                            # partition than defined!  Although, this doesn't
                            # really make a difference for correctness, since
                            # defined is guaranteed to have the symbol in
                            # scope and can return it; you just get less
                            # optimal codegen in this case.
                            s_defined = getattr(s_node, "_fx_partition", None)
                            if s_defined is not None:
                                s_def_partition = partitions[s_defined]
                                s_def_partition.outputs.setdefault(s_node.name)
                                s_def_partition.dependents.setdefault(used)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node):
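        # Ask split_callback for this node's partition, creating the Partition
        # record on first use, and remember the assignment on the node itself
        # (node._fx_partition) for the later passes.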
        partition_name = str(split_callback(node))
        log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global state nodes are nodes which, through their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    for node in m.graph.nodes:
        # This will prefer placeholder bindings, because those come first.
        # This is a little dangerous though: it is possible that an unbacked
        # symbol is used without any binding site for it, in which case we
        # will get a KeyError not able to find it.  I'd like to fix this by
        # having passes.runtime_assert establish some invariants that I can
        # rely on later, but this needs some extra work.  Quick fix first.
        # See https://github.com/pytorch/pytorch/issues/130534
        if (
            (val := node.meta.get("example_value")) is not None and
            isinstance(val, torch.SymInt) and
            isinstance(s0 := val.node.expr, sympy.Symbol) and
            s0 not in symbol_to_node
        ):
            symbol_to_node[val.node.expr] = node

        if node.op in ["placeholder", "get_attr", "output"]:
            continue

        instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = set({split_callback(node)})
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = set({split_callback(node)})
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        if active_grad is not None:
            grad_regions[active_grad].add(split_callback(node))

        for a in active_autocasts:
            autocast_regions[a].add(split_callback(node))

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

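    # Replaying grad/autocast state across submodules (below) assumes that
    # split_callback assigns non-decreasing partition ids as the graph is walked
    # in order; this is checked per node in the loop that follows.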
    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO: currently placeholders/parameters aren't put into random partitions;
        # rather, they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = split_callback(node)
            assert highest_partition <= pid, \
348                ("autocast or set_grad_enabled require monotonically increasing partitions:"
349                 f"highest: {highest_partition}, this node's: {pid}")
            highest_partition = pid

        # Do not capture cross-partition dependencies for global state nodes, as they will be
        # self-contained; their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
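    # (Kahn's algorithm: repeatedly peel off partitions with no remaining
    # dependencies; if any partition is left unsorted, there must be a cycle.)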
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude
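    # For every region after the first, replay the _enter_autocast / _set_grad_enabled
    # node at the start of that partition's subgraph so the submodule re-establishes
    # the global state it runs under.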
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(arg for arg in node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # target = target_atoms[-1]
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue
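    # For every region except the last, append the matching _exit_autocast call at
    # the end of that partition's subgraph so the autocast state does not leak
    # across submodule boundaries.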
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: Dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node

    # Iterate over the partitions in topological (or original) order again to:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order, or in the original order when keep_original_order is set

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    already_constructed_attr_nodes = set()

    # We need to insert the placeholder nodes in the original order;
    # otherwise the graph signature will be wrong.
    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # first get the attr nodes required by this partition
            orig_mod_attr_nodes: List[Node] = [
                orig_mod_env[key] for key in partition.inputs if key not in original_order
            ]

            for node in original_order:
                if node in already_constructed_attr_nodes:
                    continue  # already added this attr to the base graph
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

            # Construct GraphModule for this partition
            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
    log.debug(
        "%s",
        lazy_format_graph_code(
            "post split_module", ret, colored=True
        ),
    )
    return ret
