# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from typing import final, List

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch
from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
    FuseConsecutiveTranspose,
)
from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram

DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def _context_loader_binary(node: "torch.fx.Node"):
    """Return the precompiled context binary carried by a loader ``node``.

    ``node`` is a ``call_function`` node whose target has no registered QNN
    node visitor. The only such target accepted is the op of the same name
    under ``torch.ops.<OpContextLoader.namespace>``. For that op, the value
    stored in ``node.meta[OpContextLoader.meta_ctx_bin]`` is returned.

    Raises:
        RuntimeError: If the target cannot be resolved under the loader
            namespace, does not match ``node.target``, or the node carries
            no binary in its meta.
    """
    op_name = node.target.__name__
    err_msg = (
        f"For {node}, {node.op}:{op_name} "
        "is not supported in Qnn Delegate"
    )
    try:
        # Resolve torch.ops.<namespace>.<op_name> one attribute at a time.
        # This replaces eval(..., globals().update(torch.__dict__)): that
        # argument evaluated to None and, as a side effect, copied torch's
        # whole namespace (including __name__/__spec__/__package__) into
        # this module's globals on every call.
        loader_target = getattr(torch.ops, OpContextLoader.namespace)
        for attr in op_name.split("."):
            loader_target = getattr(loader_target, attr)
    except Exception as e:
        raise RuntimeError(err_msg) from e
    if not node.target == loader_target:
        raise RuntimeError(err_msg)
    try:
        return node.meta[OpContextLoader.meta_ctx_bin]
    except KeyError as e:
        raise RuntimeError(err_msg) from e


@final
class QnnBackend(BackendDetails):
    """ExecuTorch backend that lowers an edge program to a QNN context binary."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Compile ``edge_program`` into a QNN context binary.

        Steps:
        1. Run the QNN-specific graph passes.
        2. Convert every supported ``call_function`` node into QNN op
           wrappers via the registered node visitors.
        3. Compile the wrappers with a ``QnnManager`` configured from
           ``compile_specs``.

        If the graph contains a context-binary loader node, the binary stored
        in that node's meta is returned directly instead.

        Args:
            edge_program: The exported edge-dialect program to lower.
            compile_specs: Backend options, forwarded to
                ``generate_qnn_executorch_option``.

        Returns:
            A ``PreprocessResult`` holding the context binary. Its
            ``debug_handle_map`` is always empty.

        Raises:
            RuntimeError: If a node's operator is not supported by the QNN
                delegate, or a node's ``op`` is not one of call_function,
                get_attr, placeholder or output.
            AssertionError: If the pass manager returns ``None`` or QNN
                produces an empty context binary.
        """
        option = generate_qnn_executorch_option(compile_specs)
        qnn_manager = PyQnnManager.QnnManager(option)
        qnn_manager.Init()
        # Release the manager on every exit path (context-loader early
        # return, unsupported-op errors, compile failures), not only after a
        # successful Compile.
        try:
            # QNN Delegate Specific Passes
            qnn_compiler_passes = PassManager(
                passes=[
                    InsertRequantize(edge_program),
                    InsertIOQDQ(edge_program),
                    LayoutTransform(edge_program, insert_permute=True),
                    FuseConsecutiveTranspose(),
                ]
            )

            pass_result = qnn_compiler_passes(edge_program.graph_module)
            assert pass_result is not None

            enable_tensor_dump = qnn_manager.IsTensorDump()
            nodes_to_wrappers = defaultdict(dict)
            node_visitors = get_node_visitors(
                edge_program, enable_tensor_dump=enable_tensor_dump
            )
            py_op_wrapper_list = []
            for node in pass_result.graph_module.graph.nodes:
                if node.op in ("get_attr", "placeholder", "output"):
                    continue
                if node.op != "call_function":
                    raise RuntimeError(f"{node.op} is not supported in Qnn")

                op_name = node.target.__name__
                logger.info("Visiting: %s, %s", node, op_name)
                if op_name not in node_visitors:
                    # If graph has context binary loader node, return directly;
                    # anything else raises RuntimeError inside the helper.
                    return PreprocessResult(
                        processed_bytes=_context_loader_binary(node),
                        debug_handle_map={},
                    )

                py_op_wrapper = node_visitors[op_name].define_node(
                    node, nodes_to_wrappers
                )
                if py_op_wrapper is None:
                    continue
                # A visitor may emit a single wrapper or a list of them.
                if isinstance(py_op_wrapper, list):
                    py_op_wrapper_list.extend(py_op_wrapper)
                else:
                    py_op_wrapper_list.append(py_op_wrapper)

            qnn_context_binary = qnn_manager.Compile(
                qnn_manager.GetGraphNames()[0],
                [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list],
            )
            assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary."
            # For now, debug_handle_map is not used by QNN ExecuTorch
            return PreprocessResult(
                processed_bytes=bytes(qnn_context_binary),
                debug_handle_map={},
            )
        finally:
            qnn_manager.Destroy()