1import copy 2import typing 3 4import torch 5from torch.export.exported_program import _decompose_exported_program 6 7 8def _copy_graph_module_and_signature( 9 ep: torch.fx.GraphModule, 10) -> typing.Tuple[ 11 torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature 12]: 13 # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), 14 # and this can break placeholder names in some particular cases. 15 # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. 16 # So we manually overwrite placeholder names by reading the old graph. 17 gm = copy.deepcopy(ep.graph_module) 18 new_graph_signature = copy.deepcopy(ep.graph_signature) 19 20 # iterate over old/new graph modules 21 for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): 22 old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] 23 new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] 24 # iterate over placeholders 25 assert len(old_phs) == len(new_phs) 26 for old_node, new_node in zip(old_phs, new_phs): 27 new_node.name = old_node.name 28 29 return gm, new_graph_signature 30 31 32def _remove_detach_pass( 33 gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature 34) -> None: 35 with gm._set_replace_hook(sig.get_replace_hook()): 36 for node in list(reversed(gm.graph.nodes)): 37 if node.op != "call_function": 38 continue 39 if ( 40 node.target == torch.ops.aten.detach.default 41 and len(node.users) == 1 42 and next(iter(node.users)).target == torch.ops.aten.detach.default 43 ): 44 next(iter(node.users)).replace_all_uses_with(node) 45 46 gm.graph.eliminate_dead_code() 47 gm.recompile() 48 49 50def _export_forward_backward( 51 ep: torch.export.ExportedProgram, joint_loss_index: int = 0 52) -> torch.export.ExportedProgram: 53 """ 54 WARNING: This API is highly unstable and will be subject to change in the future. 55 """ 56 from torch._decomp import core_aten_decompositions 57 58 ep = _decompose_exported_program( 59 ep, 60 decomp_table=core_aten_decompositions(), 61 _preserve_ops=(), # type: ignore[arg-type] 62 joint_loss_index=joint_loss_index, 63 ) 64 gm, new_graph_signature = _copy_graph_module_and_signature(ep) 65 _remove_detach_pass(gm, new_graph_signature) 66 67 return ep._update(gm, new_graph_signature) 68