xref: /aosp_15_r20/external/pytorch/torch/export/unflatten.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import abc
3import copy
4import operator
5from collections import defaultdict
6from contextlib import contextmanager
7from copy import deepcopy
8from enum import Enum
9from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
10
11import torch
12import torch.fx._pytree as fx_pytree
13import torch.utils._pytree as pytree
14from torch._library.fake_class_registry import FakeScriptObject
15from torch.export._tree_utils import reorder_kwargs
16from torch.export.exported_program import (
17    ConstantArgument,
18    ExportedProgram,
19    InputKind,
20    ModuleCallSignature,
21    SymIntArgument,
22    TensorArgument,
23)
24from torch.fx._symbolic_trace import is_fx_tracing
25from torch.fx.graph_module import _print_readable
26from torch.utils._pytree import GetAttrKey, SequenceKey
27
28from ._remove_effect_tokens_pass import _remove_effect_tokens
29
30
31__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]
32
33
34class _AttrKind(Enum):
35    PARAMETER = "parameter"
36    BUFFER = "buffer"
37    CONSTANT = "constant"
38
39
40RUN_WITH_INTERPRETER = True
41
42
43@contextmanager
44def _disable_interpreter():
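    # InterpreterModule/UnflattenedModule instances constructed while this
    # context is active capture RUN_WITH_INTERPRETER=False, which makes their
    # forward() prefer GraphModule codegen over torch.fx.Interpreter (see
    # InterpreterModule.forward and UnflattenedModule.forward).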
45    global RUN_WITH_INTERPRETER
46    old_flag = RUN_WITH_INTERPRETER
47    RUN_WITH_INTERPRETER = False
48    try:
49        yield
50    finally:
51        RUN_WITH_INTERPRETER = old_flag
52
53
54# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
55# This installs empty Modules where none exist yet, if they are subpaths of 'target'.
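# For example (illustrative), _assign_attr(param, root, "sub.linear.weight",
# _AttrKind.PARAMETER) creates empty Modules at root.sub and root.sub.linear if
# they do not already exist, then registers `param` as root.sub.linear.weight.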
56def _assign_attr(
57    from_obj: Union[torch.Tensor, torch.ScriptObject],
58    to_module: torch.nn.Module,
59    target: str,
60    attr_kind: _AttrKind,
61    persistent: bool = True,
62):
63    *prefix, field = target.split(".")
64    for item in prefix:
65        t = getattr(to_module, item, None)
66
67        if t is None:
68            t = torch.nn.Module()
69            setattr(to_module, item, t)
70        to_module = t
71
72    if attr_kind == _AttrKind.PARAMETER:
73        assert isinstance(from_obj, torch.nn.Parameter)
74        to_module.register_parameter(field, from_obj)
75    elif attr_kind == _AttrKind.BUFFER:
76        assert isinstance(from_obj, torch.Tensor)
77        to_module.register_buffer(field, from_obj, persistent=persistent)
78    elif attr_kind == _AttrKind.CONSTANT:
79        assert not isinstance(
80            from_obj, FakeScriptObject
81        ), "FakeScriptObject should only exist during tracing."
82        assert isinstance(
83            from_obj,
84            (
85                torch.Tensor,
86                torch.ScriptObject,
87            ),
88        )
89        setattr(to_module, field, from_obj)
90
91
92class InterpreterModule(torch.nn.Module):
93    """A module that uses torch.fx.Interpreter to execute instead of the usual
94    codegen that GraphModule uses. This provides better stack trace information
95    and makes it easier to debug execution.
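
    Example::

        >>> # A minimal sketch; instances are normally produced by
        >>> # ``torch.export.unflatten`` rather than constructed by hand.
        >>> # ``ep`` is an assumed ExportedProgram and the submodule name
        >>> # ``submod`` is illustrative.
        >>> unflat = torch.export.unflatten(ep)
        >>> sub = unflat.get_submodule("submod")  # typically an InterpreterModule
        >>> out = sub(torch.randn(2, 4))          # runs via torch.fx.Interpreter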
96    """
97
98    def __init__(
99        self,
100        graph: torch.fx.Graph,
101    ):
102        super().__init__()
103        self.graph = graph
104        self.graph.owning_module = self
105        self._run_with_interpreter = RUN_WITH_INTERPRETER
106
107    def forward(self, *args, **kwargs):
108        assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
109        if not is_fx_tracing() and (
110            torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
111        ):
112            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
113            # GraphModule codegen in this instance.
114            # Patch the codegened forward to run with this InterpreterModule,
115            # so attribute accesses, etc. are on this module instead.
116            return type(self.graph_module).forward(self, *args, **kwargs)
117        else:
118            if kwargs:
119                # Handle **kwargs. FX only natively supports positional
120                # arguments (through placeholders). So in order to pass in
121                # kwargs, we must match the names of the placeholders with
122                # the keys in the kwarg dict.
123                arg_list = list(args)
124                kwarg_names = self.arg_names[len(arg_list) :]
125                for kwarg_name in kwarg_names:
126                    if kwarg_name in kwargs:
127                        arg_list.append(kwargs[kwarg_name])
128
129                # Assert that the kwargs passed in exactly match the positional
130                # arguments specified by the GraphModule. This should be
131                # guaranteed by the unflattening process.
132                assert len(kwarg_names) == len(kwargs)
133                assert len(arg_list) == len(self.arg_names)
134                args = tuple(arg_list)
135
136            return torch.fx.Interpreter(self, graph=self.graph).run(
137                *args, enable_io_processing=False
138            )
139
140    def finalize(self):
141        # We need to "finalize" because GraphModule populates its own state_dict
142        # based on the get_attrs observed in the graph. So we need to fully
143        # construct the graph and call _sink_params before generating this
144        # GraphModule.
145
146        # need to set `graph_module` directly on the dict to avoid it getting
147        # registered as a submodule.
148        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
149        self.graph.lint()
150
151        # Cache arg names for kwarg handling (see forward())
152        self.arg_names = []
153        for node in self.graph.nodes:
154            if node.op == "placeholder":
155                self.arg_names.append(node.target)
156
157    def print_readable(
158        self,
159        print_output=True,
160        include_stride=False,
161        include_device=False,
162        colored=False,
163    ):
164        return _print_readable(
165            self,
166            "InterpreterModule",
167            print_output,
168            include_stride,
169            include_device,
170            colored,
171        )
172
173
174class FlatArgsAdapter(abc.ABC):
175    """
176    Adapts input arguments with ``input_spec`` to align ``target_spec``.
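
    Example::

        >>> # A minimal sketch of a custom adapter that drops extra trailing
        >>> # inputs; ``ep`` is an assumed ExportedProgram.
        >>> class DropExtraArgs(FlatArgsAdapter):
        ...     def adapt(self, target_spec, input_spec, input_args):
        ...         return input_args[: target_spec.num_leaves]
        >>> unflat = torch.export.unflatten(ep, flat_args_adapter=DropExtraArgs())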
177    """
178
179    @abc.abstractmethod
180    def adapt(
181        self,
182        target_spec: pytree.TreeSpec,
183        input_spec: pytree.TreeSpec,
184        input_args: List[Any],
185    ) -> List[Any]:
186        """NOTE: This adapter may mutate the given ``input_args``."""
187        ...
188
189
190class UnflattenedModule(torch.nn.Module):
191    def __init__(
192        self,
193        export_module: ExportedProgram,
194        flat_args_adapter: Optional[FlatArgsAdapter] = None,
195    ):
196        super().__init__()
197        if export_module.graph_signature.backward_signature is not None:
198            raise ValueError("Unflattening on JointExportModule NYI")
199
200        fqn_list = [entry.fqn for entry in export_module.module_call_graph]
201        assert fqn_list[0] == ""
202        export_graph = deepcopy(export_module.graph)
203        self.graph_signature = deepcopy(export_module.graph_signature)
204        self.graph = torch.fx.Graph()
205        self.module_call_graph = deepcopy(export_module.module_call_graph)
206        self.flat_args_adapter = flat_args_adapter
207        # Flag to indicate whether args have been adapted.
208        self.adapted = False
209        self._run_with_interpreter = RUN_WITH_INTERPRETER
210
211        _inplace_buffer_mutations(export_graph, self.graph_signature)
212        _outline_submodules(export_graph, self)
213
214        self.range_constraints = export_module.range_constraints
215        self.equality_constraints: List = []
216
217        # aliasing/unused param or buffer issues:
218        # in strict-mode export, dynamo export will deduplicate aliased tensors,
219        # and ignore unused tensors. For aliasing, this causes issues when some aliases
220        # are unused, and we're unable to match the placeholder node to the correct FQN.
221        # This leads to the graph signature potentially having the wrong target FQN,
222        # and downstream issues where parameters are assigned to the wrong target attribute,
223        # mismatching the relevant placeholder node in the unflattened module.
224        # To resolve this we restore (_assign_attr) all aliased/unused tensors in
225        # the state_dict as module attributes, but only keep the used tensors in the
226        # graph's forward pass (_sink_params).
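        # For example (illustrative), if `self.a.weight` and `self.b.weight` alias the
        # same tensor and only `a.weight` appears in the graph signature, both FQNs are
        # still assigned as module attributes below, while only the used placeholder is
        # wired to module state by _sink_params.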
227        state_dict = export_module.state_dict
228        assigned_params: Set[str] = set()  # tracks assigned params, to detect unused ones later
229        id_to_param: Dict[int, torch.nn.Parameter] = {}  # handling weight-sharing
230        for name in self.graph_signature.parameters:  # this loop adds used params
231            param = state_dict[name]
232            if id(param) not in id_to_param:
233                id_to_param[id(param)] = torch.nn.Parameter(
234                    param.clone(), requires_grad=param.requires_grad
235                )
236
237            _assign_attr(
238                id_to_param[id(param)],
239                self,
240                name,
241                attr_kind=_AttrKind.PARAMETER,
242            )
243            assigned_params.add(name)
244
245        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
246        assigned_buffers: Set[str] = set()  # tracks assigned buffers, to detect unused ones later
247        id_to_buffer: Dict[
248            int, Tuple[torch.Tensor, bool]
249        ] = {}  # handle weight-sharing
250        for name in self.graph_signature.buffers:  # this loop adds used buffers
251            if name in non_persistent_buffers:
252                persistent = False
253                buffer = export_module.constants[name]
254            else:
255                persistent = True
256                buffer = state_dict[name]
257
258            if id(buffer) not in id_to_buffer:
259                id_to_buffer[id(buffer)] = (buffer.clone(), persistent)
260
261            _assign_attr(
262                id_to_buffer[id(buffer)][0],
263                self,
264                name,
265                attr_kind=_AttrKind.BUFFER,
266                persistent=persistent,
267            )
268            assigned_buffers.add(name)
269
270        # restore aliased/unused params and buffers
271        # these appear in state dict but not graph signature
272        for name, tensor in state_dict.items():
273            if name in assigned_params or name in assigned_buffers:  # already assigned
274                continue
275
276            is_buffer = False
277            if id(tensor) in id_to_buffer or not isinstance(
278                tensor, torch.nn.Parameter
279            ):  # aliased buffer
280                is_buffer = True
281
282            if is_buffer:
283                if (
284                    id(tensor) not in id_to_buffer
285                ):  # this is completely unused (not weight-sharing)
286                    id_to_buffer[id(tensor)] = (
287                        tensor,
288                        True,
289                    )  # assign to respect original model
290                _assign_attr(
291                    id_to_buffer[id(tensor)][0],
292                    self,
293                    name,
294                    attr_kind=_AttrKind.BUFFER,
295                    persistent=True,
296                )
297            else:
298                if id(tensor) not in id_to_param:  # this is unused
299                    id_to_param[id(tensor)] = tensor
300                _assign_attr(
301                    id_to_param[id(tensor)],
302                    self,
303                    name,
304                    attr_kind=_AttrKind.PARAMETER,
305                )
306
307        # use id map so we don't double-clone aliased constants
308        id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
309        for fqn, constant in export_module.constants.items():
310            if id(constant) not in id_to_const:
311                if isinstance(constant, torch.Tensor):
312                    constant = constant.clone()
313                id_to_const[id(constant)] = constant
314            _constant = id_to_const[id(constant)]
315            _assign_attr(
316                _constant,
317                self,
318                fqn,
319                attr_kind=_AttrKind.CONSTANT,
320            )
321
322        # This is to handle parameters/buffers that point to the same tensor
323        # object id -> list of (node_name, target_name)
324        consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
325        consts_targets: Set[str] = set()
326
327        def add_to_consts_map(obj_id, node_name, target_name):
328            name_list = consts_map[obj_id]
329            name_list.append((node_name, target_name))
330
331        added_params_buffers: Set[str] = set()  # track aliased/unused params, buffers
332        for s in self.graph_signature.input_specs:
333            if s.kind == InputKind.PARAMETER or (
334                s.kind == InputKind.BUFFER and s.persistent
335            ):
336                assert hasattr(s.arg, "name")
337                assert isinstance(s.target, str)
338                add_to_consts_map(
339                    id(export_module.state_dict[s.target]), s.arg.name, s.target
340                )
341                consts_targets.add(s.target)
342                added_params_buffers.add(s.target)
343            elif (
344                (s.kind == InputKind.BUFFER and not s.persistent)
345                or s.kind == InputKind.CONSTANT_TENSOR
346                or s.kind == InputKind.CUSTOM_OBJ
347            ):
348                assert hasattr(s.arg, "name")
349                assert isinstance(s.target, str)
350                add_to_consts_map(
351                    id(export_module.constants[s.target]), s.arg.name, s.target
352                )
353                consts_targets.add(s.target)
354
355        # add constants that are aliased and don't appear in graph signature
356        for const_name, const in export_module.constants.items():
357            if const_name not in consts_targets:
358                assert (
359                    id(const) in consts_map
360                ), "Constants should be either aliased or appear in graph signature"
361                ph_name, _ = consts_map[id(const)][0]
362                add_to_consts_map(id(const), ph_name, const_name)
363                added_params_buffers.add(const_name)
364
365        # add aliased/unused params and buffers that don't appear in graph signature
366        for fqn, tensor in export_module.state_dict.items():
367            if fqn not in added_params_buffers:
368                if id(tensor) not in consts_map:
369                    # completely unused (no weight-sharing), ignore.
370                    # this weight doesn't appear in graph module,
371                    # so won't cause FQN assignment issues
372                    continue
373                ph_name, _ = consts_map[id(tensor)][0]
374                add_to_consts_map(id(tensor), ph_name, fqn)
375
376        # node name -> list of possible targets
377        inputs_to_state: Dict[str, List[str]] = {}
378        for node_target in consts_map.values():
379            targets = [t[1] for t in node_target]
380            for n, _ in node_target:
381                inputs_to_state[n] = targets
382
383        _sink_params(self, inputs_to_state, [])
384
385        # Helper function to check that the input nodes of `module` have been processed.
386        def check_module_inputs(module, scope):
387            if hasattr(module, "graph"):
388                for node in module.graph.nodes:
389                    # sink_params() should turn placeholders into get_attr nodes
390                    # for attributes that are within scope of the current
391                    # module. We allow attributes to remain as placeholders if
392                    # they are inputs in the original module signature, meaning
393                    # they are a parent module's attribute, and therefore out of
394                    # scope of the current module.
395                    if (
396                        node.op == "placeholder"
397                        and node.name in inputs_to_state
398                        and any(
399                            fqn.split(".")[: len(scope)] == scope
400                            for fqn in inputs_to_state[node.name]
401                        )  # matching scope to avoid wrong assert
402                    ):
403                        raise AssertionError(
404                            f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}"
405                        )
406            # Recursively check the submodules.
407            for name, submod in module.named_children():
408                scope.append(name)
409                check_module_inputs(submod, scope)
410
411        # Recursively check that all input nodes have been processed.
412        check_module_inputs(self, [])
413
414        # Cache so we don't have to compute this every time.
415        # NOTE: this needs to be kept in sync with the placeholders in
416        # self.graph, but currently we have no way to guarantee that.
417        self.input_placeholders = [
418            node for node in self.graph.nodes if node.op == "placeholder"
419        ]
420        self.check_input_constraints = True
421        # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
422        fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
423        # In the case of legacy IR, we might be missing some modules from metadata.
424        for name, _ in self.named_modules(remove_duplicate=False):
425            if name not in fqn_order:
426                fqn_order[name] = len(fqn_order)
427        _reorder_submodules(self, fqn_order)
428        assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list(
429            fqn_order.keys()
430        )
431        self.graph.lint()
432
433    def _print_graph(self):
434        for fqn, mod in self.named_modules():
435            print(fqn + ":")
436            if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
437                print(mod.graph)
438
439    def forward(self, *args, **kwargs):
440        signature = self.module_call_graph[0].signature
441
442        reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
443
444        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
445            (args, reordered_kwargs)
446        )
447        flat_args = [x[1] for x in flat_args_with_path]
448        if is_fx_tracing():
449            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
450                *flat_args, enable_io_processing=False
451            )
452            # For scalar return value, fx.Graph wraps in a tuple
453            if isinstance(return_val, tuple) and len(return_val) == 1:
454                return return_val[0]
455            return return_val
456
457        if in_spec != signature.in_spec:
458            if not self.adapted:
459                print(
460                    "Input treespec does not match the exported module's treespec: \n"
461                    f"Input treespec: {in_spec}. ",
462                    f"Exported module treespec: {signature.in_spec}",
463                )
464            if self.flat_args_adapter is None:
465                raise TypeError(
466                    "There is no flat args adapter specified. "
467                    "Are you sure you are calling this with the right arguments? "
468                )
469            else:
470                if not self.adapted:
471                    print("Adapting flat args to match the exported module's treespec")
472                flat_args = self.flat_args_adapter.adapt(
473                    target_spec=signature.in_spec,
474                    input_spec=in_spec,
475                    input_args=flat_args,
476                )
477                self.adapted = True
478                if len(flat_args) != signature.in_spec.num_leaves:
479                    raise TypeError(
480                        f"Flat args adaptation failed; number of args mismatch. "
481                        f"Adapted: {len(flat_args)} \n"
482                        f"Exported module: {signature.in_spec.num_leaves}"
483                    )
484
485        if self.check_input_constraints:
486            # Import here to avoid an unfortunate circular dependency.
487            # TODO(suo): untangle this.
488            from torch._export.utils import _check_input_constraints_for_graph
489
490            if self.adapted is True:
491                # TODO(suo): The FlatArgsAdapter returns a list of flat args,
492                # which we don't have keypaths for. For now, just create a dummy
493                # keypath to associate with the arg.
494                new_flat_args_with_path = [  # type: ignore[var-annotated]
495                    ((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg)
496                    for arg in flat_args
497                ]
498            else:
499                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]
500
501            _check_input_constraints_for_graph(
502                self.input_placeholders, new_flat_args_with_path, self.range_constraints
503            )
504        if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter:
505            tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args)
506        else:
507            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
508                *flat_args, enable_io_processing=False
509            )
510        return pytree.tree_unflatten(tree_out, signature.out_spec)
511
512    def print_readable(
513        self,
514        print_output=True,
515        include_stride=False,
516        include_device=False,
517        colored=False,
518    ):
519        return _print_readable(
520            self,
521            "UnflattenedModule",
522            print_output,
523            include_stride,
524            include_device,
525            colored,
526        )
527
528
529def unflatten(
530    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
531) -> UnflattenedModule:
532    """Unflatten an ExportedProgram, producing a module with the same module
533    hierarchy as the original eager module. This can be useful if you are trying
534    to use :mod:`torch.export` with another system that expects a module
535    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
536
537    .. note:: The args/kwargs of unflattened modules will not necessarily match
538        the eager module, so doing a module swap (e.g. :code:`self.submod =
539        new_mod`) will not necessarily work. If you need to swap a module out, you
540        need to set the :code:`preserve_module_call_signature` parameter of
541        :func:`torch.export.export`.
542
543    Args:
544        module (ExportedProgram): The ExportedProgram to unflatten.
545        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.
546
547    Returns:
548        An instance of :class:`UnflattenedModule`, which has the same module
549        hierarchy as the original eager module pre-export.
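
    Example::

        >>> # A minimal sketch; the module and example inputs are illustrative.
        >>> class M(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(4, 4)
        ...     def forward(self, x):
        ...         return self.linear(x)
        >>> ep = torch.export.export(M(), (torch.randn(2, 4),))
        >>> unflat = unflatten(ep)
        >>> sub = unflat.linear  # the eager module hierarchy is preserved
        >>> out = unflat(torch.randn(2, 4))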
550    """
551    module = _remove_effect_tokens(module)
552    return UnflattenedModule(module, flat_args_adapter)
553
554
555def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
556    """Transform buffer mutations from their functionalized form into a copy_
557    node in the graph.
558
559    Functionalization represents buffer mutation by passing the buffer as an input and output. So for example, the eager code:
560        def forward(self, x):
561            self.buffer += x
562            return x * x
563
564    Will become a graph that looks like:
565        def forward(self, buffer, x):
566            mutated_buffer = aten.add(buffer, x)
567            mul = aten.mul(x, x)
568            return (mutated_buffer, mul)
569
570    We want to turn this back into something that looks like the original eager code:
571        def forward(self, buffer, x):
572            mutated_buffer = aten.add(buffer, x)
573            buffer.copy_(mutated_buffer)
574            mul = aten.mul(x, x)
575            return (mul,)
576    """
577    output_node = next(iter(reversed(graph.nodes)))
578    assert output_node.op == "output" and len(output_node.args) == 1
579    return_args = output_node.args[0]
580
581    mutation_node_to_buffer = graph_signature.buffers_to_mutate
582    mutations = return_args[: len(mutation_node_to_buffer)]
583    buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()}
584    input_name_to_node = {
585        node.name: node for node in graph.nodes if node.op == "placeholder"
586    }
587
588    for mutation in mutations:
589        buffer_name = mutation_node_to_buffer[mutation.name]
590        input_name = buffers_to_inputs[buffer_name]
591        input_node = input_name_to_node[input_name]
592
593        with graph.inserting_after(mutation):
594            new_node = graph.create_node(
595                "call_function", torch.ops.aten.copy_, (input_node, mutation)
596            )
597            for k, v in mutation.meta.items():
598                new_node.meta[k] = v
599        # Replace all uses of the previously functional mutation with our copy_ output.
600        mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
601
602    # Remove the mutated buffer from the graph outputs, since we don't need to
603    # thread it through anymore. We don't need to handle the inputs, which will
604    # be handled by _sink_params.
605    user_outputs = tuple(
606        return_args[len(mutation_node_to_buffer) :],
607    )
608    output_node.args = ((user_outputs),)
609
610
611def _is_prefix(candidate, target):
612    """Check whether `candidate` is a prefix of `target`."""
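    # For example, _is_prefix(["a"], ["a", "b"]) is True; a list is never
    # considered a prefix of itself.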
613    return len(candidate) < len(target) and target[: len(candidate)] == candidate
614
615
616def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
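    # For example, _compute_accessor("a.b", "a.b.c.d") returns "c.d".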
617    if parent_fqn == "":
618        # Handle the root module correctly.
619        return child_fqn
620
621    parent_split = parent_fqn.split(".")
622    child_split = child_fqn.split(".")
623
624    # TODO: support skip connection by inlining the child module.
625    if child_split[: len(parent_split)] != parent_split:
626        raise RuntimeError(
627            f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'. "
628            "This is currently unsupported. "
629            "Please attach the child module directly to the parent module."
630        )
631    return ".".join(child_split[len(parent_split) :])
632
633
634def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
635    def graph_dump(graph: torch.fx.Graph) -> str:
636        ret = []
637        nodes_idx: Dict[int, int] = {}
638
639        def arg_dump(arg) -> str:
640            if isinstance(arg, torch.fx.Node):
641                return "%" + str(nodes_idx[id(arg)])
642            return str(arg)
643
644        for i, node in enumerate(graph.nodes):
645            args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
646            args_dump += [
647                f"{key}={value}"
648                for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
649            ]
650            target = node.target if node.op == "call_function" else ""
651            ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
652            nodes_idx[id(node)] = i
653        return "\n".join(ret)
654
655    assert graph_dump(x.graph) == graph_dump(y.graph)
656
657
658def _add_spec(gm: torch.nn.Module, spec) -> str:
659    i = 0
660    while hasattr(gm, f"_spec_{i}"):
661        i += 1
662    name = f"_spec_{i}"
663    setattr(gm, name, spec)
664    return name
665
666
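# The helpers below store a pytree TreeSpec on `gm` under a fresh `_spec_<i>`
# attribute and emit a call_function node that flattens (or unflattens) values
# against that spec at runtime.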
667def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node:
668    name = _add_spec(gm, spec)
669    spec_node = gm.graph.get_attr(name)
670    return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
671
672
673def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node:
674    name = _add_spec(gm, spec)
675    spec_node = gm.graph.get_attr(name)
676    return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
677
678
679def _get_submodule(mod: torch.nn.Module, target: str):
680    *prefix, field = target.split(".")
681
682    for item in prefix:
683        submod = getattr(mod, item, None)
684
685        if submod is None:
686            return None
687
688        if not isinstance(submod, torch.nn.Module):
689            return None
690
691        mod = submod
692
693    return getattr(mod, field, None)
694
695
696def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
697    *prefix, field = target.split(".")
698
699    for item in prefix:
700        submod = getattr(mod, item, None)
701
702        if submod is None:
703            submod = torch.nn.Module()
704            setattr(mod, item, submod)
705
706        if not isinstance(submod, torch.nn.Module):
707            return False
708
709        mod = submod
710
711    mod.add_module(field, module_to_add)
712
713
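# _ModuleFrame rebuilds the module hierarchy from the flat export graph: each
# frame owns one submodule's graph, copies nodes whose nn_module_stack matches
# its scope, recurses into a child frame when a deeper scope begins, and emits
# a call_module node (plus pytree flatten/unflatten glue when a
# ModuleCallSignature is preserved for the child's FQN) in the parent's graph.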
714class _ModuleFrame:
715    def __init__(
716        self,
717        flat_graph: torch.fx.Graph,
718        nodes: Tuple[torch.fx.Node, ...],
719        seen_nodes,
720        seen_modules,
721        parent,
722        module_stack: List[str],
723        module_id,
724        module_call_graph: Dict[str, ModuleCallSignature],
725        module: Optional[torch.nn.Module] = None,
726    ):
727        self.flat_graph = flat_graph
728        self.nodes = nodes
729        self.seen_nodes = seen_nodes
730        self.seen_modules = seen_modules
731        self.parent = parent
732        self.module_stack = module_stack
733        self.module_id = module_id
734
735        self.module_call_graph = module_call_graph
736        self.verbose = False
737
738        self.fqn = self.module_stack[-1]
739        if module is not None:
740            self.module = module
741        else:
742            self.module = InterpreterModule(torch.fx.Graph())
743        if self.module_id in self.seen_modules:
744            self.cached_graph_module = self.seen_modules[self.module_id]
745        else:
746            self.cached_graph_module = None
747            self.seen_modules[self.module_id] = self.module
748
749        self.graph = self.module.graph
750
751        # Mapping of nodes in the flat graph to nodes in this graph.
752        self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
753        self.node_to_placeholder = {}
754
755        self.parent_call_module: Optional[torch.fx.Node] = None
756        if parent is not None:
757            accessor = _compute_accessor(parent.fqn, self.fqn)
758            _add_submodule(
759                parent.module,
760                accessor,
761                (
762                    self.module
763                    if self.cached_graph_module is None
764                    else self.cached_graph_module
765                ),
766            )
767            self.parent_call_module = parent.graph.call_module(accessor)
768
769        signature = module_call_graph.get(self.fqn)
770        if signature is not None and self.parent is not None:
771            assert signature.in_spec.num_children == 2
772            args_spec = signature.in_spec.children_specs[0]
773            kwargs_spec = signature.in_spec.children_specs[1]
774            assert args_spec.context is None
775            assert kwargs_spec.context is not None
776
777            with self.graph.inserting_after(None):
778                arg_nodes = []
779                for idx in range(args_spec.num_children):
780                    arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
781                kwarg_nodes = {}
782                for name in kwargs_spec.context:
783                    kwarg_nodes[name] = self.graph.placeholder(name)
784                flat_args = _generate_flatten(
785                    self.module,
786                    (tuple(arg_nodes), kwarg_nodes),
787                    signature.in_spec,
788                )
789                for idx, arg in enumerate(signature.inputs):
790                    flat_arg_node = self.graph.create_node(
791                        op="call_function",
792                        target=operator.getitem,
793                        args=(flat_args, idx),
794                        name=(
795                            arg.name
796                            if not isinstance(arg, ConstantArgument)
797                            else f"_constant_{idx}"
798                        ),
799                    )
800                    if isinstance(arg, ConstantArgument):
801                        continue
802
803                    if arg.name in self.seen_nodes:
804                        flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
805                        self.node_to_placeholder[
806                            self.seen_nodes[arg.name]
807                        ] = flat_arg_node
808
809            with self.parent.graph.inserting_before(self.parent_call_module):
810                input_nodes: List[Optional[torch.fx.Node]] = []
811                for input in signature.inputs:
812                    if isinstance(input, ConstantArgument) and input.value is None:
813                        input_nodes.append(None)
814                    elif input.name not in self.seen_nodes:
815                        input_nodes.append(None)
816                    else:
817                        assert isinstance(input, (TensorArgument, SymIntArgument))
818                        input_nodes.append(
819                            self.parent.remap_input(self.seen_nodes[input.name])
820                        )
821
822                inputs_node = _generate_unflatten(
823                    self.parent.module,
824                    input_nodes,
825                    signature.in_spec,
826                )
827
828                args_node = self.parent.graph.call_function(
829                    operator.getitem, (inputs_node, 0)
830                )
831                kwargs_node = self.parent.graph.call_function(
832                    operator.getitem, (inputs_node, 1)
833                )
834                arg_nodes = [
835                    self.parent.graph.call_function(operator.getitem, (args_node, i))
836                    for i in range(args_spec.num_children)
837                ]
838                kwarg_nodes = {
839                    k: self.parent.graph.call_function(
840                        operator.getitem, (kwargs_node, k)
841                    )
842                    for k in kwargs_spec.context
843                }
844            assert self.parent_call_module is not None
845            self.parent_call_module.args = tuple(arg_nodes)
846            self.parent_call_module.kwargs = kwarg_nodes
847
848    def add_placeholder(self, x):
849        assert self.fqn != "", f"Cannot add placeholder {x} to root module"
850        assert x.graph is self.flat_graph
851        # x is not in this subgraph, so create a new placeholder for it
852        with self.graph.inserting_before(None):
853            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
854        # copy all meta fields, even if some fields might be irrelevant for
855        # the placeholder node
856        placeholder_node.meta = copy.copy(x.meta)
857        self.node_to_placeholder[x] = placeholder_node
858
859    def copy_sym_call_function(self, x):
860        # This only exists because we deduplicate sym_size nodes in the flat export graph,
861        # and if preserve_module_call_signature is set, we may not be able to pass sym_size
862        # nodes, or their downstream users, as inputs to submodule calls.
863        # To avoid this we copy these call_function nodes with sym_type results.
864        # This should however only be done for sym_type nodes - call_function nodes on tensors
865        # should not be deduplicated in the first place.
866        args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
867        kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
868        node = self.graph.call_function(x.target, args, kwargs)
869        node.meta = copy.copy(x.meta)
870        self.node_map[x] = node
871        return node
872
873    def remap_input(self, x):
874        assert x.graph is self.flat_graph
875        if x in self.node_map:
876            return self.node_map[x]
877        self.print(f"remap_input({x})")
878        if x in self.node_to_placeholder:
879            return self.node_to_placeholder[x]
880        elif (
881            x.op == "placeholder"
882            or self.module_call_graph.get(self.fqn) is None
883            # allow placeholder creation if we are not preserving module call signature
884        ):
885            self.add_placeholder(x)
886            if self.parent_call_module is not None:
887                # Important to *prepend* the input to match how we are
888                # inserting placeholder nodes.
889                with self.parent.graph.inserting_before(self.parent_call_module):
890                    self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
891            return self.node_to_placeholder[x]
892        elif x.op == "call_function" and (
893            x.target
894            in (
895                torch.ops.aten.sym_size.int,
896                torch.ops.aten.item.default,
897                torch.ops.aten.unbind.int,
898                torch.ops.aten.sum.dim_IntList,
899                torch.ops.aten.view.default,
900                torch.ops.aten.diff.default,
901            )
902            or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
903        ):
904            # export deduplicates sym_size nodes, and may need to re-copy them
905            # if module call signature needs to be preserved
906            self.copy_sym_call_function(x)
907            return self.node_map[x]
908        else:
909            raise RuntimeError(
910                f"Could not run remap_input() on op type: {x.op} for node {x}"
911            )
912
913    def finalize_outputs(self):
914        orig_outputs = []
915
916        signature = self.module_call_graph.get(self.fqn)
917        if signature is not None and self.parent is not None:
918            for output in signature.outputs:
919                if isinstance(output, (TensorArgument, SymIntArgument)):
920                    if output.name in self.seen_nodes:
921                        orig_outputs.append(self.seen_nodes[output.name])
922                    else:
923                        orig_outputs.append(None)
924                else:
925                    raise RuntimeError(
926                        f"Unsupported data type for output node: {output}"
927                    )
928
929            def get_actual_output_node(output):
930                if output is None:
931                    return None
932
933                seen_node = self.seen_nodes[output.name]
934                if seen_node in self.node_map:
935                    return self.node_map[seen_node]
936                elif seen_node in self.node_to_placeholder:
937                    return self.node_to_placeholder[seen_node]
938                else:
939                    raise RuntimeError(
940                        f"Could not find output node {output}. Graph: {self.graph}"
941                    )
942
943            tree_out_node = _generate_unflatten(
944                self.module,
945                tuple(get_actual_output_node(output) for output in orig_outputs),
946                signature.out_spec,
947            )
948            parent_out: Optional[torch.fx.Node] = _generate_flatten(
949                self.parent.module, self.parent_call_module, signature.out_spec
950            )
951            graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
952        else:
953            graph_outputs = []
954            # Iterate through nodes we have copied into self.graph.
955            for orig_node in self.node_map.keys():
956                for user_node in orig_node.users:
957                    if user_node.name not in self.seen_nodes:
958                        # external user node, need to expose as an output
959                        orig_outputs.append(orig_node)
960                        graph_outputs.append(self.node_map[orig_node])
961                        break
962
963            parent_out = self.parent_call_module
964            if len(graph_outputs) == 1:
965                graph_outputs = graph_outputs[0]
966
967        assert isinstance(graph_outputs, (list, torch.fx.Node))
968
969        self.graph.output(graph_outputs)
970
971        # Rewrite outputs in parent module
972        if parent_out is None:
973            return
974
975        parent_out.meta["val"] = (
976            graph_outputs.meta.get("val")
977            if isinstance(graph_outputs, torch.fx.Node)
978            else [o.meta.get("val") for o in graph_outputs]
979        )
980
981        if len(orig_outputs) == 1 and signature is None:
982            self.parent.node_map[orig_outputs[0]] = parent_out
983        else:
984            for i, orig_output in enumerate(orig_outputs):
985                if orig_output is None:
986                    continue
987                # Use Proxy to record getitem access.
988                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
989                proxy_out.meta["val"] = orig_output.meta.get("val")
990                self.parent.node_map[orig_output] = proxy_out
991
992        if self.cached_graph_module is not None:
993            _verify_graph_equivalence(self.cached_graph_module, self.module)
994
995    def copy_node(self, node):
996        self.print("copying", node.format_node())
997        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
998        self.seen_nodes[node.name] = node
999
1000    def run_outer(self):
1001        i = 0
1002        for node in self.flat_graph.nodes:
1003            self.print(i, node.meta.get("nn_module_stack"), node.format_node())
1004            i += 1
1005
1006        # Copy all graph inputs
1007        node_idx: int = 0
1008        node = self.nodes[node_idx]
1009        while node.op == "placeholder":
1010            self.copy_node(node)
1011            node_idx += 1
1012            node = self.nodes[node_idx]
1013
1014        self.run_from(node_idx)
1015
1016        # Copy graph outputs
1017        for node in self.flat_graph.nodes:
1018            if node.op == "output":
1019                self.copy_node(node)
1020
1021    def print(self, *args, **kwargs):
1022        if self.verbose:
1023            print(*args, **kwargs)
1024
1025    def run_from(self, node_idx):
1026        module_idx = 0
1027        # Walk through the graph, building up a new graph with the right submodules
1028        while node_idx < len(self.nodes):
1029            node = self.nodes[node_idx]
1030            assert node.op != "placeholder"
1031
1032            self.print()
1033            self.print("STEP", node_idx, node.format_node())
1034            self.print(self.module_stack)
1035            if node.op == "output":
1036                if len(self.module_stack) == 1:
1037                    # We want the output node of the original graph to be handled
1038                    # specially by the outermost stack frame (in run_outer). So
1039                    # skip finalization here.
1040                    return node_idx
1041
1042                # We've reached the end of the graph. Wrap up all the existing stack frames.
1043                self.finalize_outputs()
1044                return node_idx
1045
1046            if len(node.meta.get("nn_module_stack", {})) == 0:
1047                raise RuntimeError(f"Unable to find nn_module_stack for node {node}")
1048
1049            nn_module_stack = node.meta["nn_module_stack"]
1050            from torch._export.passes._node_metadata_hook import (
1051                _EMPTY_NN_MODULE_STACK_KEY,
1052            )
1053
1054            if (
1055                len(nn_module_stack) == 1
1056                and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
1057            ):
1058                # Empty case from the node_metadata_hook
1059                node_module_stack = self.module_stack
1060            else:
1061                node_module_stack = [
1062                    path for path, ty in node.meta["nn_module_stack"].values()
1063                ]
1064
1065            if node_module_stack[: len(self.module_stack)] != self.module_stack:
1066                # This means that the current module is done executing and the
1067                # current node is the beginning of a new module.
1068                #
1069                # In this case, we should finalize this module and return without
1070                # incrementing the node counter.
1071                self.finalize_outputs()
1072                self.print("outlining", self.fqn)
1073                self.print(self.graph)
1074                return node_idx
1075
1076            assert node_module_stack is not None
1077
1078            if _is_prefix(self.module_stack, node_module_stack):
1079                # This means that the current node represents the execution of a new
1080                # module.
1081                next_module = node_module_stack[len(self.module_stack)]
1082                self.print("Creating new stack frame for", next_module)
1083                # Run a nested version of module outliner from the current node
1084                # counter. Once it is complete, continue from that point.
1085                node_idx = _ModuleFrame(
1086                    self.flat_graph,
1087                    self.nodes,
1088                    self.seen_nodes,
1089                    self.seen_modules,
1090                    self,
1091                    self.module_stack + [next_module],
1092                    list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
1093                    self.module_call_graph,
1094                ).run_from(node_idx)
1095                module_idx += 1
1096                continue
1097
1098            # The only remaining possibility is that we are in the right stack
1099            # frame. Copy the node into this frame's graph and increment the node counter.
1100            assert node_module_stack == self.module_stack
1101            self.copy_node(node)
1102            node_idx += 1
1103
1104
1105def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
1106    seen_nodes: Dict[str, torch.fx.Node] = {}
1107    seen_modules: Dict[int, torch.nn.Module] = {}
1108    _ModuleFrame(
1109        orig_graph,
1110        tuple(orig_graph.nodes),
1111        seen_nodes,
1112        seen_modules,
1113        None,
1114        [""],
1115        "",
1116        {
1117            entry.fqn: entry.signature
1118            for entry in root_module.module_call_graph
1119            if entry.signature
1120        },
1121        module=root_module,
1122    ).run_outer()
1123
1124
1125def _reorder_submodules(
1126    parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
1127):
1128    # TODO Can be optimized by adding submodules ahead of time.
1129    if prefix == "":
1130        for fqn in list(fqn_order.keys())[1:]:
1131            if _get_submodule(parent, fqn) is None:
1132                _add_submodule(parent, fqn, torch.nn.Module())
1133
1134    children = []
1135    for name, child in list(parent._modules.items()):
1136        if child is None:
1137            continue
1138        fqn = prefix + name
1139        _reorder_submodules(child, fqn_order, prefix=fqn + ".")
1140        delattr(parent, name)
1141        children.append((fqn_order[fqn], name, child))
1142    children.sort(key=operator.itemgetter(0))
1143    for _, name, child in children:
1144        parent.register_module(name, child)
1145
1146
1147def _sink_params(
1148    module: torch.nn.Module,
1149    inputs_to_state: Dict[str, List[str]],
1150    scope: List[str],
1151):
1152    """Sink params, buffers, and constants from graph inputs into get_attr nodes.
1153
1154    Exported modules are purely functional, so they pass their parameters and
1155    buffers in as inputs to the graph.
1156
1157    To replicate eager's semantics, we need to get them from the module state
1158    via get_attr instead.
1159
1160    module: GraphModule, potentially containing nested submodules.
1161    inputs_to_state: mapping graph input names to the corresponding keys in the state_dict.
1162    scope: tracks where we are in the module hierarchy, so that we can emit the
1163        right `getattr(self, "foo.bar")` calls, etc.
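
    For example (an illustrative sketch, with ``scope=["sub"]`` and
    ``inputs_to_state={"p_weight": ["sub.weight"]}``), a submodule graph like::

        def forward(self, p_weight, x):
            return torch.add(x, p_weight)

    is rewritten to read the tensor from module state instead::

        def forward(self, x):
            weight = self.weight
            return torch.add(x, weight)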
1164    """
1165    # This dict records inputs removed by child modules.
1166    # Maps the module object id to the list of placeholder node names
1167    # in the child module that were removed.
1168    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)
1169
1170    # We need to use _modules here instead of named_children(), because we
1171    # explicitly want duplicate modules to show up in the traversal.
1172    for name, submodule in module._modules.items():
1173        submod_id_to_inputs_removed = _sink_params(
1174            cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]
1175        )
1176        for k, v in submod_id_to_inputs_removed.items():
1177            module_id_to_inputs_removed[k].extend(v)
1178
1179    if not hasattr(module, "graph"):
1180        # Not all modules have a graph; modules with no operations (like ParameterList) do not.
1181        return module_id_to_inputs_removed
1182
1183    graph = module.graph
1184    inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
1185    the_last_input = inputs[-1]
1186
1187    # Also remove from call_module nodes
1188    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
1189    for node in call_module_nodes:
1190        submodule = _recursive_getattr(module, node.target.split("."))
1191        # remove placeholder from call_module node arguments, only if we've
1192        # erased the placeholder node in the corresponding _sink_params() call
1193        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
1194            node.args = tuple(
1195                filter(
1196                    lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
1197                    node.args,
1198                )
1199            )
1200
1201    # Filter out inputs_to_state corresponding to current scope.
1202    inputs_to_state_of_scope: Dict[torch.fx.Node, List[str]] = {}
1203    for node in inputs:
1204        if node.name not in inputs_to_state:
1205            continue
1206
1207        state_name = None
1208        for sn in inputs_to_state[node.name]:
1209            sn_split = sn.split(".")
1210            if sn_split[: len(scope)] == scope:
1211                state_name = sn_split
1212                break
1213
1214        # If there's a mismatch between the scope name and the state name, then
1215        # there must be multiple scopes pointing to the same state name,
1216        # meaning some modules are shared. In such a case, we can simply skip
1217        # updating the current node, because a later iteration will
1218        # take care of this input node when the unique match between scope
1219        # and state name occurs. To make sure this always happens, we should
1220        # enforce the invariant that no placeholder node in the unflattened
1221        # graph appears in the inputs_to_state dict, which means all the extra
1222        # input nodes have been handled.
1223        if state_name is None:
1224            continue
1225
1226        inputs_to_state_of_scope[node] = state_name
1227
1228    # Record the names of removed inputs for the return value.
1229    inputs_removed: List[str] = []
1230
1231    for node, state_name in inputs_to_state_of_scope.items():
1232        if len(node.users) > 0:
1233            attr_path = state_name[len(scope) :]
1234            state_attr = _recursive_getattr(module, attr_path)
1235            assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
1236
1237            # Make sure the newly created get_attr node is placed after the last placeholder node
1238            with graph.inserting_after(the_last_input):
1239                new_node = graph.create_node("get_attr", ".".join(attr_path))
1240
1241            node.replace_all_uses_with(new_node, propagate_meta=True)
1242
1243        graph.erase_node(node)
1244        inputs_removed.append(node.name)
1245
1246    if isinstance(module, InterpreterModule):
1247        module.finalize()
1248
1249    return {id(module): inputs_removed}
1250
1251
1252def _recursive_getattr(obj, attr_path):
1253    for attr in attr_path:
1254        if not hasattr(obj, attr):
1255            return None
1256        obj = getattr(obj, attr)
1257
1258    return obj
1259