xref: /aosp_15_r20/external/executorch/backends/qualcomm/qnn_preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import final, List
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch  # noqa: F401
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
15*523fa7a6SAndroid Build Coastguard Worker    FuseConsecutiveTranspose,
16*523fa7a6SAndroid Build Coastguard Worker)
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import (
24*523fa7a6SAndroid Build Coastguard Worker    BackendDetails,
25*523fa7a6SAndroid Build Coastguard Worker    CompileSpec,
26*523fa7a6SAndroid Build Coastguard Worker    PreprocessResult,
27*523fa7a6SAndroid Build Coastguard Worker)
28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import PassManager
29*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram
30*523fa7a6SAndroid Build Coastguard Worker
# Module-level logger, named after this module; its level is forced to DEBUG
# at import time.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Debug-handle constant exported by this module (not referenced in this file).
DEFAULT_DEBUG_HANDLE = 65535
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker
@final
class QnnBackend(BackendDetails):
    """Backend entry point that turns an edge program into a QNN context binary."""

    @staticmethod
    def _resolve_context_loader_target(op_name: str):
        """Return ``torch.ops.<OpContextLoader.namespace>.<op_name>``.

        The lookup is done with plain ``getattr`` calls instead of ``eval``:
        evaluating a string built from a node name could run arbitrary
        expressions, and the previous ``globals().update(torch.__dict__)``
        argument both evaluated to ``None`` and copied every ``torch``
        attribute into this module's globals. Dots in the combined name are
        treated as attribute access, exactly as the evaluated expression
        treated them.

        Raises:
            Whatever ``getattr`` raises when a component cannot be resolved.
        """
        target = torch.ops
        for attr in f"{OpContextLoader.namespace}.{op_name}".split("."):
            target = getattr(target, attr)
        return target

    @staticmethod
    def _lower(edge_program: ExportedProgram, qnn_manager) -> PreprocessResult:
        """Run the QNN passes, visit every node and compile the graph.

        Args:
            edge_program: Program whose ``graph_module`` is transformed and
                lowered.
            qnn_manager: Already initialised ``PyQnnManager.QnnManager``; the
                caller owns its lifetime (this helper never destroys it).

        Returns:
            A ``PreprocessResult`` holding either the context binary stored in
            a context-loader node's ``meta`` or the freshly compiled one.

        Raises:
            RuntimeError: For a node kind other than ``call_function``,
                ``get_attr``, ``placeholder`` or ``output``, or for a
                ``call_function`` node that has no visitor and is not the
                context-loader op.
            AssertionError: If the pass manager returns ``None`` or the
                compiled context binary is empty.
        """
        # QNN Delegate Specific Passes
        qnn_compiler_passes = PassManager(
            passes=[
                InsertRequantize(edge_program),
                InsertIOQDQ(edge_program),
                LayoutTransform(edge_program, insert_permute=True),
                FuseConsecutiveTranspose(),
            ]
        )

        pass_result = qnn_compiler_passes(edge_program.graph_module)
        assert pass_result is not None

        enable_tensor_dump = qnn_manager.IsTensorDump()
        nodes_to_wrappers = defaultdict(dict)
        node_visitors = get_node_visitors(
            edge_program, enable_tensor_dump=enable_tensor_dump
        )
        py_op_wrapper_list = []
        for node in pass_result.graph_module.graph.nodes:
            # These node kinds produce no op wrapper of their own.
            if node.op in ("get_attr", "placeholder", "output"):
                continue
            if node.op != "call_function":
                raise RuntimeError(f"{node.op} is not supported in Qnn")

            op_name = node.target.__name__
            logger.info("Visiting: %s, %s", node, op_name)

            if op_name in node_visitors:
                py_op_wrapper = node_visitors[op_name].define_node(
                    node, nodes_to_wrappers
                )
                # A visitor may yield nothing, one wrapper or a list of them.
                if py_op_wrapper is None:
                    continue
                if isinstance(py_op_wrapper, list):
                    py_op_wrapper_list.extend(py_op_wrapper)
                else:
                    py_op_wrapper_list.append(py_op_wrapper)
                continue

            # No visitor: the only node still accepted is the context binary
            # loader op, which carries its binary in ``node.meta``.
            err_msg = (
                f"For {node}, {node.op}:{node.target.__name__} "
                "is not supported in Qnn Delegate"
            )
            try:
                context_loader_target = QnnBackend._resolve_context_loader_target(
                    op_name
                )
                is_context_loader = node.target == context_loader_target
                context_binary = (
                    node.meta[OpContextLoader.meta_ctx_bin]
                    if is_context_loader
                    else None
                )
            except Exception as err:
                # Keep the RuntimeError callers already see, but preserve the
                # underlying cause instead of discarding it.
                raise RuntimeError(err_msg) from err
            if not is_context_loader:
                raise RuntimeError(err_msg)
            # if graph has context binary loader node, return directly
            return PreprocessResult(
                processed_bytes=context_binary,
                debug_handle_map={},
            )

        qnn_context_binary = qnn_manager.Compile(
            qnn_manager.GetGraphNames()[0],
            [wrapper.GetOpWrapper() for wrapper in py_op_wrapper_list],
        )
        assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary."
        # For now, debug_handle_map is not used by QNN ExecuTorch
        return PreprocessResult(
            processed_bytes=bytes(qnn_context_binary),
            debug_handle_map={},
        )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower ``edge_program`` to a QNN context binary.

        Args:
            edge_program: Exported program to lower.
            compile_specs: Compile specs from which the QNN ExecuTorch option
                is generated.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` is the context
            binary and whose ``debug_handle_map`` is empty.

        Raises:
            RuntimeError: If the graph contains an unsupported node.
            AssertionError: If the passes yield no result or compilation
                produces an empty binary.
        """
        option = generate_qnn_executorch_option(compile_specs)
        qnn_manager = PyQnnManager.QnnManager(option)
        qnn_manager.Init()
        try:
            return QnnBackend._lower(edge_program, qnn_manager)
        finally:
            # Release the manager on every exit path, including the early
            # context-loader return and any failure while lowering.
            qnn_manager.Destroy()
117