# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.passes.executorch_prim_ops_registry import (
    _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS,
)
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class EdgeToBackendOpsPass(PassBase):
    """
    Converts
    1. symbolic int ops to the executorch_prims namespaced ops
    2. other backend ops from torch._ops.OpOverload to BackendOpOverload
    """

    @staticmethod
    def _get_backend_op(
        target: torch._ops.OpOverload,
    ) -> torch._ops.OpOverload:
        """Return the ``ops.backend`` counterpart of ``target`` if one exists.

        Walks ``ops.backend.<namespace>.<op_name>.<overload_name>``, checking
        each level's ``_dir`` before descending into it. If any level is not
        registered, ``target`` is returned unchanged.

        Args:
            target: The ``OpOverload`` currently used as a node target.

        Returns:
            The backend op found under ``ops.backend``, or ``target`` itself
            when no matching backend op is registered.
        """
        namespace = target.namespace
        # The schema name has the form "<namespace>::<op_name>"; keep only
        # the op name part.
        name = target._schema.name.split("::")[1]
        overload_name = target._overloadname
        obj = ops.backend
        # Look names up via `_dir` (as exposed by _DialectNamespace,
        # _OpNamespace and BackendOpOverloadPacket) before descending.
        for key in (namespace, name, overload_name):
            if key not in obj._dir:
                return target
            obj = getattr(obj, key)
        return obj

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Retarget ``call_function`` nodes in every nested ``GraphModule``.

        Nodes whose target is a key of
        ``_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS`` are remapped to the
        corresponding executorch sym op. Otherwise, ``OpOverload`` targets are
        replaced by their backend op when one is registered. Each visited
        ``GraphModule`` is recompiled afterwards.

        Args:
            graph_module: The root module; it and all submodules that are
                ``GraphModule`` instances are rewritten in place.

        Returns:
            ``PassResult(graph_module, True)`` — the pass always reports the
            graph as modified.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            for node in module.graph.nodes:
                if node.op != "call_function":
                    continue

                if node.target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS:
                    node.target = _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS[node.target]
                elif isinstance(node.target, torch._ops.OpOverload):
                    # Replace torch.ops.OpOverload with its backend op, if any.
                    node.target = self._get_backend_op(node.target)

            # Regenerate the module's Python code from the edited graph.
            module.recompile()

        return PassResult(graph_module, True)