1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9import torch 10 11from torch.fx.passes.infra.pass_base import PassBase, PassResult 12 13 14class RemoveGraphAssertsPass(PassBase): 15 """ 16 Temporary pass to remove all the assert ops until runtime decides to address it. 17 """ 18 19 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 20 for module in graph_module.modules(): 21 if not isinstance(module, torch.fx.GraphModule): 22 continue 23 24 for node in module.graph.nodes: 25 if node.op == "call_function" and ( 26 node.target 27 in ( 28 torch.ops.aten._assert_async.msg, 29 torch.ops.aten._assert_scalar.default, 30 torch.ops.aten.sym_constrain_range_for_size.default, 31 torch.ops.aten.sym_constrain_range.default, 32 ) 33 ): 34 module.graph.erase_node(node) 35 36 module.recompile() 37 module.graph.eliminate_dead_code() 38 39 return PassResult(graph_module, True) 40