xref: /aosp_15_r20/external/pytorch/torch/export/experimental/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import copy
2import typing
3
4import torch
5from torch.export.exported_program import _decompose_exported_program
6
7
8def _copy_graph_module_and_signature(
9    ep: torch.fx.GraphModule,
10) -> typing.Tuple[
11    torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature
12]:
13    # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
14    # and this can break placeholder names in some particular cases.
15    # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.
16    # So we manually overwrite placeholder names by reading the old graph.
17    gm = copy.deepcopy(ep.graph_module)
18    new_graph_signature = copy.deepcopy(ep.graph_signature)
19
20    # iterate over old/new graph modules
21    for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()):
22        old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"]
23        new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"]
24        # iterate over placeholders
25        assert len(old_phs) == len(new_phs)
26        for old_node, new_node in zip(old_phs, new_phs):
27            new_node.name = old_node.name
28
29    return gm, new_graph_signature
30
31
32def _remove_detach_pass(
33    gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature
34) -> None:
35    with gm._set_replace_hook(sig.get_replace_hook()):
36        for node in list(reversed(gm.graph.nodes)):
37            if node.op != "call_function":
38                continue
39            if (
40                node.target == torch.ops.aten.detach.default
41                and len(node.users) == 1
42                and next(iter(node.users)).target == torch.ops.aten.detach.default
43            ):
44                next(iter(node.users)).replace_all_uses_with(node)
45
46    gm.graph.eliminate_dead_code()
47    gm.recompile()
48
49
def _export_forward_backward(
    ep: torch.export.ExportedProgram, joint_loss_index: int = 0
) -> torch.export.ExportedProgram:
    """Decompose ``ep`` into a joint forward/backward program and clean it up.

    WARNING: This API is highly unstable and will be subject to change in the future.
    """
    from torch._decomp import core_aten_decompositions

    # Lower to core-ATen ops while expanding the joint (fwd + bwd) graph.
    decomposed = _decompose_exported_program(
        ep,
        decomp_table=core_aten_decompositions(),
        _preserve_ops=(),  # type: ignore[arg-type]
        joint_loss_index=joint_loss_index,
    )

    # Copy graph + signature (preserving placeholder names), strip redundant
    # detach pairs, then swap the cleaned graph back into the program.
    graph_module, signature = _copy_graph_module_and_signature(decomposed)
    _remove_detach_pass(graph_module, signature)
    return decomposed._update(graph_module, signature)
68