# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

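"""Ahead-of-time preprocess for the Qualcomm QNN delegate.

QnnBackend.preprocess lowers a partitioned ExportedProgram into a QNN
context binary, which is the payload the QNN runtime backend executes.
"""
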
import logging
from collections import defaultdict
from typing import final, List

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch  # noqa: F401
from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
    FuseConsecutiveTranspose,
)
from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram

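# Default debug handle (65535 == max uint16); presumably a fallback for nodes
# without a recorded debug handle.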
DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@final
class QnnBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
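        # Serialize the compile specs into a QNN ExecuTorch option blob and
        # initialize the QNN backend through the PyQnnManager adaptor.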
        option = generate_qnn_executorch_option(compile_specs)
        qnn_manager = PyQnnManager.QnnManager(option)
        qnn_manager.Init()

        # QNN Delegate Specific Passes
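        # Roughly: InsertRequantize and InsertIOQDQ add the requantize and I/O
        # quantize/dequantize ops the delegate expects, LayoutTransform moves
        # tensors to the QNN-preferred layout (inserting explicit permutes),
        # and FuseConsecutiveTranspose folds back-to-back transposes that can
        # result from that.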
        qnn_compiler_passes = PassManager(
            passes=[
                InsertRequantize(edge_program),
                InsertIOQDQ(edge_program),
                LayoutTransform(edge_program, insert_permute=True),
                FuseConsecutiveTranspose(),
            ]
        )

        pass_result = qnn_compiler_passes(edge_program.graph_module)
        assert pass_result is not None

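        # Build the per-op builders. enable_tensor_dump asks the builders to
        # mark intermediate tensors for output dumping (a debugging aid);
        # nodes_to_wrappers caches tensor wrappers already created so shared
        # tensors are only defined once.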
        enable_tensor_dump = qnn_manager.IsTensorDump()
        nodes_to_wrappers = defaultdict(dict)
        node_visitors = get_node_visitors(
            edge_program, enable_tensor_dump=enable_tensor_dump
        )
        py_op_wrapper_list = []
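        # Lower every call_function node into its QNN op wrapper; the remaining
        # node kinds are either skipped (placeholders, attributes, outputs) or
        # rejected as unsupported.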
        for node in pass_result.graph_module.graph.nodes:
            if node.op == "call_function":
                logger.info(f"Visiting: {node}, {node.target.__name__}")
                if node.target.__name__ in node_visitors:
                    py_op_wrapper = node_visitors[node.target.__name__].define_node(
                        node, nodes_to_wrappers
                    )
                    if py_op_wrapper is not None:
                        # define_node may return a single wrapper or a list of them
                        if isinstance(py_op_wrapper, list):
                            py_op_wrapper_list.extend(py_op_wrapper)
                        else:
                            py_op_wrapper_list.append(py_op_wrapper)
                else:
                    err_msg = (
                        f"For {node}, {node.op}:{node.target.__name__} "
                        "is not supported in Qnn Delegate"
                    )
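                    # A call_function without a visitor is only accepted if it
                    # is a pre-built context-binary loader op registered under
                    # the OpContextLoader namespace; anything else is rejected.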
                    try:
                        context_loader_target = eval(
                            f"torch.ops.{OpContextLoader.namespace}.{node.target.__name__}",
                            globals().update(torch.__dict__),
                        )
                        assert node.target == context_loader_target, err_msg
                        # If the graph carries a pre-compiled context binary,
                        # return it directly as the processed payload.
                        return PreprocessResult(
                            processed_bytes=node.meta[OpContextLoader.meta_ctx_bin],
                            debug_handle_map={},
                        )
                    except Exception:
                        raise RuntimeError(err_msg)

            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in Qnn")
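        # Compile the collected op wrappers into a serialized QNN context
        # binary for the first graph registered with the manager.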
        qnn_context_binary = qnn_manager.Compile(
            qnn_manager.GetGraphNames()[0],
            [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list],
        )
        assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary."
        qnn_manager.Destroy()
        # For now, debug_handle_map is not used by QNN ExecuTorch
        return PreprocessResult(
            processed_bytes=bytes(qnn_context_binary),
            debug_handle_map={},
        )