xref: /aosp_15_r20/external/executorch/exir/passes/remove_graph_asserts_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10
11from torch.fx.passes.infra.pass_base import PassBase, PassResult
12
13
14class RemoveGraphAssertsPass(PassBase):
15    """
16    Temporary pass to remove all the assert ops until runtime decides to address it.
17    """
18
19    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
20        for module in graph_module.modules():
21            if not isinstance(module, torch.fx.GraphModule):
22                continue
23
24            for node in module.graph.nodes:
25                if node.op == "call_function" and (
26                    node.target
27                    in (
28                        torch.ops.aten._assert_async.msg,
29                        torch.ops.aten._assert_scalar.default,
30                        torch.ops.aten.sym_constrain_range_for_size.default,
31                        torch.ops.aten.sym_constrain_range.default,
32                    )
33                ):
34                    module.graph.erase_node(node)
35
36            module.recompile()
37            module.graph.eliminate_dead_code()
38
39        return PassResult(graph_module, True)
40