1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8 9import torch 10from torch._higher_order_ops.auto_functionalize import ( 11 auto_functionalized, 12 auto_functionalized_v2, 13) 14from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized 15from torch.export import ExportedProgram 16 17 18def remove_self_clone(graph: torch.fx.Graph): 19 for node in graph.nodes: 20 if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: 21 node.replace_all_uses_with(node.args[0]) 22 graph.erase_node(node) 23 24 25def unsafe_remove_auto_functionalized_pass( 26 ep: ExportedProgram, 27) -> ExportedProgram: 28 """ 29 This pass removes an instances of the higher order op 'auto_functionalized', 30 and modifies the calling EP inplace to have the original mutator op. 31 This pass doesn't perform safety checks to make sure that this inplace mutation is safe. 32 """ 33 34 with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): 35 for module in ep.graph_module.modules(): 36 if not isinstance(module, torch.fx.GraphModule): 37 continue 38 for node in ep.graph.nodes: 39 if ( 40 node.op == "call_function" and node.target is auto_functionalized 41 ) or ( 42 node.op == "call_function" and node.target is auto_functionalized_v2 43 ): 44 func = node.args[0] 45 assert isinstance(func, torch._ops.OpOverload) 46 # re-inplace everything 47 node.meta["only_clone_these_tensors"] = [] 48 decompose_auto_functionalized(ep.graph) 49 remove_self_clone(ep.graph) 50 ep.graph.eliminate_dead_code() 51 52 return ep 53