xref: /aosp_15_r20/external/pytorch/torch/export/_remove_auto_functionalized_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8
9import torch
10from torch._higher_order_ops.auto_functionalize import (
11    auto_functionalized,
12    auto_functionalized_v2,
13)
14from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
15from torch.export import ExportedProgram
16
17
def remove_self_clone(graph: torch.fx.Graph):
    """
    Erase every ``aten.copy_.default`` node whose first two arguments are equal.

    Such a node copies a value onto itself. Each one found has all of its
    uses redirected to its first argument and is then erased from ``graph``.
    The graph is modified in place; nothing is returned.
    """
    copy_op = torch.ops.aten.copy_.default
    for candidate in graph.nodes:
        if not (candidate.target == copy_op):
            continue
        destination = candidate.args[0]
        source = candidate.args[1]
        if destination == source:
            # Users must stop referencing the node before it can be erased.
            candidate.replace_all_uses_with(destination)
            graph.erase_node(candidate)
23
24
def unsafe_remove_auto_functionalized_pass(
    ep: ExportedProgram,
) -> ExportedProgram:
    """
    This pass removes all instances of the higher order ops
    'auto_functionalized' and 'auto_functionalized_v2', and modifies the
    calling EP inplace to have the original mutator op.
    This pass doesn't perform safety checks to make sure that this inplace mutation is safe.

    Every ``torch.fx.GraphModule`` yielded by ``ep.graph_module.modules()``
    (the root module and any nested graph modules) has its own graph rewritten.

    Args:
        ep: The exported program to rewrite in place.

    Returns:
        The same ``ep`` object that was passed in.
    """

    # The replace hook is installed on the root graph module for the duration
    # of the rewrite, using the hook supplied by the graph signature.
    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        for module in ep.graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            # Work on the graph owned by *this* module. Using ``ep.graph``
            # here would reprocess the root graph on every iteration and
            # leave nested graph modules untouched.
            graph = module.graph

            for node in graph.nodes:
                if node.op != "call_function":
                    continue
                if (
                    node.target is not auto_functionalized
                    and node.target is not auto_functionalized_v2
                ):
                    continue
                func = node.args[0]
                assert isinstance(func, torch._ops.OpOverload)
                # re-inplace everything
                # (an empty list requests that no tensors be cloned;
                # presumably consumed by decompose_auto_functionalized)
                node.meta["only_clone_these_tensors"] = []

            decompose_auto_functionalized(graph)
            remove_self_clone(graph)
            graph.eliminate_dead_code()

    return ep
53