xref: /aosp_15_r20/external/executorch/exir/passes/replace_edge_with_backend_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.dialects._ops import ops
11from executorch.exir.passes.executorch_prim_ops_registry import (
12    _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS,
13)
14from torch.fx.passes.infra.pass_base import PassBase, PassResult
15
16
class EdgeToBackendOpsPass(PassBase):
    """
    Converts
    1. symbolic int ops to the executorch_prims namespaced ops
    2. other backend ops from torch._ops.OpOverload to BackendOpOverload
    """

    @staticmethod
    def _get_backend_op(target: torch._ops.OpOverload) -> torch._ops.OpOverload:
        """
        Return the object found under ``ops.backend`` for ``target``, if any.

        The lookup descends ``ops.backend`` by namespace, then op name, then
        overload name. Each key is checked against the ``_dir`` listing of the
        current level (_DialectNamespace, _OpNamespace and
        BackendOpOverloadPacket) before descending.

        Args:
            target: The operator overload currently used by a graph node.

        Returns:
            The object reached after all three lookups succeed, or ``target``
            itself, unchanged, as soon as one level does not list the key.
        """
        obj = ops.backend
        for key in (
            target.namespace,
            # Schema names look like "<namespace>::<op name>"; keep the op name.
            target._schema.name.split("::")[1],
            target._overloadname,
        ):
            if key not in obj._dir:
                return target
            obj = getattr(obj, key)
        return obj

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite ``call_function`` targets in ``graph_module`` and in every
        nested ``torch.fx.GraphModule`` it contains.

        Targets that are keys of ``_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS`` are
        replaced by the mapped value; other ``torch._ops.OpOverload`` targets
        are replaced by the result of ``_get_backend_op``. All other nodes are
        left alone. Every visited GraphModule is recompiled.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``. ``modified`` is
            True only if at least one node ended up with a different target.
        """
        modified = False
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            for node in module.graph.nodes:
                if node.op != "call_function":
                    continue

                target = node.target
                if target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS:
                    new_target = _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS[target]
                elif isinstance(target, torch._ops.OpOverload):
                    new_target = self._get_backend_op(target)
                else:
                    continue

                # A lookup that falls back to the original target is not a
                # modification, so only record a change on a real swap.
                if new_target is not target:
                    node.target = new_target
                    modified = True

            # Recompile unconditionally, matching the previous behaviour.
            module.recompile()

        return PassResult(graph_module, modified)
57