1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport operator 8*523fa7a6SAndroid Build Coastguard Workerimport warnings 9*523fa7a6SAndroid Build Coastguard Workerfrom collections import OrderedDict 10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, FrozenSet, List, Tuple 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Workerimport executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir 15*523fa7a6SAndroid Build Coastguard Worker 16*523fa7a6SAndroid Build Coastguard Workerimport torch 17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.annotate_and_quant_scalar import ( 18*523fa7a6SAndroid Build Coastguard Worker AnnotateAndQuantScalar, 19*523fa7a6SAndroid Build Coastguard Worker) 20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs 22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import ( 23*523fa7a6SAndroid Build Coastguard Worker ConvertBinaryOpsWithScalar, 24*523fa7a6SAndroid Build Coastguard Worker) 25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.convert_bmm_to_matmul import ( 
26*523fa7a6SAndroid Build Coastguard Worker ConvertBmmToMatmul, 27*523fa7a6SAndroid Build Coastguard Worker) 28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.convert_interpolate_with_upsample2d import ( 29*523fa7a6SAndroid Build Coastguard Worker ConvertInterpolateWithUpsample2D, 30*523fa7a6SAndroid Build Coastguard Worker) 31*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.convert_prelu import ConvertPReLU 32*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear 33*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import ( 34*523fa7a6SAndroid Build Coastguard Worker ExpandBroadcastTensorShape, 35*523fa7a6SAndroid Build Coastguard Worker) 36*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ 37*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 38*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.layout_transform import LayoutTransform 39*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import ( 40*523fa7a6SAndroid Build Coastguard Worker RecomposePixelUnshuffle, 41*523fa7a6SAndroid Build Coastguard Worker) 42*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.recompose_rms_norm import RecomposeRmsNorm 43*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy 44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.replace_index_put_input import ( 45*523fa7a6SAndroid Build Coastguard Worker ReplaceIndexPutInput, 46*523fa7a6SAndroid Build Coastguard Worker) 47*523fa7a6SAndroid Build Coastguard Worker 48*523fa7a6SAndroid Build Coastguard 
Workerfrom executorch.backends.qualcomm.builders.node_visitor import ( 49*523fa7a6SAndroid Build Coastguard Worker QNN_QUANT_TYPE_MAP, 50*523fa7a6SAndroid Build Coastguard Worker QNN_TENSOR_TYPE_MAP, 51*523fa7a6SAndroid Build Coastguard Worker) 52*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader 53*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.partition.qnn_partitioner import ( 54*523fa7a6SAndroid Build Coastguard Worker generate_qnn_executorch_option, 55*523fa7a6SAndroid Build Coastguard Worker QnnPartitioner, 56*523fa7a6SAndroid Build Coastguard Worker) 57*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.serialization.qc_schema import ( 58*523fa7a6SAndroid Build Coastguard Worker _soc_info_table, 59*523fa7a6SAndroid Build Coastguard Worker HtpArch, 60*523fa7a6SAndroid Build Coastguard Worker QcomChipset, 61*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchBackendOptions, 62*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchBackendType, 63*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchHtpBackendOptions, 64*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchHtpPerformanceMode, 65*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchHtpPrecision, 66*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchLogLevel, 67*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchOptions, 68*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchProfileLevel, 69*523fa7a6SAndroid Build Coastguard Worker) 70*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.serialization.qc_schema_serialize import ( 71*523fa7a6SAndroid Build Coastguard Worker flatbuffer_to_option, 72*523fa7a6SAndroid Build Coastguard Worker option_to_flatbuffer, 73*523fa7a6SAndroid Build Coastguard Worker) 74*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.utils.constants import ( 75*523fa7a6SAndroid Build Coastguard 
Worker QCOM_PASS_EXPAND_BROADCAST_SHAPE, 76*523fa7a6SAndroid Build Coastguard Worker QCOM_PASS_SKIP_ADVANCED_REQUANT, 77*523fa7a6SAndroid Build Coastguard Worker QCOM_QNN_COMPILE_SPEC, 78*523fa7a6SAndroid Build Coastguard Worker QCOM_QUANTIZED_IO, 79*523fa7a6SAndroid Build Coastguard Worker) 80*523fa7a6SAndroid Build Coastguard Worker 81*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import ( 82*523fa7a6SAndroid Build Coastguard Worker EdgeCompileConfig, 83*523fa7a6SAndroid Build Coastguard Worker ExecutorchProgramManager, 84*523fa7a6SAndroid Build Coastguard Worker ExirExportedProgram, 85*523fa7a6SAndroid Build Coastguard Worker to_edge, 86*523fa7a6SAndroid Build Coastguard Worker) 87*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec 88*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture import ExecutorchBackendConfig 89*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import LoweredBackendModule 90*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program._program import _get_updated_graph_signature 91*523fa7a6SAndroid Build Coastguard Workerfrom torch._decomp import core_aten_decompositions as torch_core_aten_decompositions 92*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram 93*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import passes 94*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.operator_support import OperatorSupportBase 95*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import Library 96*523fa7a6SAndroid Build Coastguard Worker 97*523fa7a6SAndroid Build Coastguard Worker 98*523fa7a6SAndroid Build Coastguard Workerclass _AnnotationSkipper(OperatorSupportBase): 99*523fa7a6SAndroid Build Coastguard Worker """ 100*523fa7a6SAndroid Build Coastguard Worker Class used to partition out unwanted graph nodes. 
101*523fa7a6SAndroid Build Coastguard Worker e.g. - nodes are prevented from quantization annotation 102*523fa7a6SAndroid Build Coastguard Worker - nodes have been grouped together as a submodule 103*523fa7a6SAndroid Build Coastguard Worker 104*523fa7a6SAndroid Build Coastguard Worker Attributes 105*523fa7a6SAndroid Build Coastguard Worker ---------- 106*523fa7a6SAndroid Build Coastguard Worker fp_node_id_set : set 107*523fa7a6SAndroid Build Coastguard Worker a set contains nodes' name to be left in fp precision 108*523fa7a6SAndroid Build Coastguard Worker fp_node_op_set : set 109*523fa7a6SAndroid Build Coastguard Worker a set contains nodes' target (aten dialect) to be left in fp precision 110*523fa7a6SAndroid Build Coastguard Worker skip_annotated_submodule : bool 111*523fa7a6SAndroid Build Coastguard Worker flag to skip annotated submodule or not 112*523fa7a6SAndroid Build Coastguard Worker 113*523fa7a6SAndroid Build Coastguard Worker Methods 114*523fa7a6SAndroid Build Coastguard Worker ------- 115*523fa7a6SAndroid Build Coastguard Worker should_delegate(n: torch.fx.Node) 116*523fa7a6SAndroid Build Coastguard Worker identify the residual nodes haven't be lowered with fixed-precision 117*523fa7a6SAndroid Build Coastguard Worker should_skip(n: torch.fx.Node) 118*523fa7a6SAndroid Build Coastguard Worker identify the nodes should be kept out with fixed-precision or not 119*523fa7a6SAndroid Build Coastguard Worker is_node_supported(_, node: torch.fx.Node) 120*523fa7a6SAndroid Build Coastguard Worker overridden method for graph partitioning 121*523fa7a6SAndroid Build Coastguard Worker """ 122*523fa7a6SAndroid Build Coastguard Worker 123*523fa7a6SAndroid Build Coastguard Worker def __init__( 124*523fa7a6SAndroid Build Coastguard Worker self, 125*523fa7a6SAndroid Build Coastguard Worker fp_node_id_set: set = None, 126*523fa7a6SAndroid Build Coastguard Worker fp_node_op_set: set = None, 127*523fa7a6SAndroid Build Coastguard Worker skip_annotated_submodule: bool = False, 
128*523fa7a6SAndroid Build Coastguard Worker ): 129*523fa7a6SAndroid Build Coastguard Worker self.fp_node_id_set = fp_node_id_set 130*523fa7a6SAndroid Build Coastguard Worker self.fp_node_op_set = fp_node_op_set 131*523fa7a6SAndroid Build Coastguard Worker self.skip_annotated_submodule = skip_annotated_submodule 132*523fa7a6SAndroid Build Coastguard Worker 133*523fa7a6SAndroid Build Coastguard Worker def should_delegate(self, n: torch.fx.Node): 134*523fa7a6SAndroid Build Coastguard Worker return n.op == "call_function" and n.target != operator.getitem 135*523fa7a6SAndroid Build Coastguard Worker 136*523fa7a6SAndroid Build Coastguard Worker def should_skip(self, n: torch.fx.Node): 137*523fa7a6SAndroid Build Coastguard Worker return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set 138*523fa7a6SAndroid Build Coastguard Worker 139*523fa7a6SAndroid Build Coastguard Worker def is_node_supported(self, _, node: torch.fx.Node) -> bool: 140*523fa7a6SAndroid Build Coastguard Worker if self.skip_annotated_submodule: 141*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr": 142*523fa7a6SAndroid Build Coastguard Worker return all(self.should_delegate(user) for user in node.users) 143*523fa7a6SAndroid Build Coastguard Worker return self.should_delegate(node) 144*523fa7a6SAndroid Build Coastguard Worker 145*523fa7a6SAndroid Build Coastguard Worker if any( 146*523fa7a6SAndroid Build Coastguard Worker [ 147*523fa7a6SAndroid Build Coastguard Worker node.op in ("placeholder", "output"), 148*523fa7a6SAndroid Build Coastguard Worker self.should_skip(node), 149*523fa7a6SAndroid Build Coastguard Worker # check if parameters belong to fallbacked operator 150*523fa7a6SAndroid Build Coastguard Worker ( 151*523fa7a6SAndroid Build Coastguard Worker node.op == "get_attr" 152*523fa7a6SAndroid Build Coastguard Worker and all(self.should_skip(user) for user in node.users) 153*523fa7a6SAndroid Build Coastguard Worker ), 154*523fa7a6SAndroid Build Coastguard Worker ] 
155*523fa7a6SAndroid Build Coastguard Worker ): 156*523fa7a6SAndroid Build Coastguard Worker print(f"[QNN Quantizer Annotation]: {node.name} | Skipped") 157*523fa7a6SAndroid Build Coastguard Worker return False 158*523fa7a6SAndroid Build Coastguard Worker 159*523fa7a6SAndroid Build Coastguard Worker return True 160*523fa7a6SAndroid Build Coastguard Worker 161*523fa7a6SAndroid Build Coastguard Worker 162*523fa7a6SAndroid Build Coastguard Workerdef qnn_capture_config(): 163*523fa7a6SAndroid Build Coastguard Worker return exir.CaptureConfig(enable_aot=True) 164*523fa7a6SAndroid Build Coastguard Worker 165*523fa7a6SAndroid Build Coastguard Worker 166*523fa7a6SAndroid Build Coastguard Workerdef qnn_edge_config() -> exir.EdgeCompileConfig: 167*523fa7a6SAndroid Build Coastguard Worker return exir.EdgeCompileConfig( 168*523fa7a6SAndroid Build Coastguard Worker _check_ir_validity=False, 169*523fa7a6SAndroid Build Coastguard Worker _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. 
170*523fa7a6SAndroid Build Coastguard Worker ) 171*523fa7a6SAndroid Build Coastguard Worker 172*523fa7a6SAndroid Build Coastguard Worker 173*523fa7a6SAndroid Build Coastguard Workerdef convert_linear_to_conv2d(module: torch.nn.Module): 174*523fa7a6SAndroid Build Coastguard Worker class Conv2D(torch.nn.Module): 175*523fa7a6SAndroid Build Coastguard Worker def __init__(self, weight, bias=None): 176*523fa7a6SAndroid Build Coastguard Worker super().__init__() 177*523fa7a6SAndroid Build Coastguard Worker use_bias = bias is not None 178*523fa7a6SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d( 179*523fa7a6SAndroid Build Coastguard Worker in_channels=weight.shape[0], 180*523fa7a6SAndroid Build Coastguard Worker out_channels=weight.shape[1], 181*523fa7a6SAndroid Build Coastguard Worker kernel_size=1, 182*523fa7a6SAndroid Build Coastguard Worker padding=0, 183*523fa7a6SAndroid Build Coastguard Worker bias=use_bias, 184*523fa7a6SAndroid Build Coastguard Worker ) 185*523fa7a6SAndroid Build Coastguard Worker self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1)) 186*523fa7a6SAndroid Build Coastguard Worker if use_bias: 187*523fa7a6SAndroid Build Coastguard Worker self.conv.bias = torch.nn.Parameter(bias) 188*523fa7a6SAndroid Build Coastguard Worker 189*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 190*523fa7a6SAndroid Build Coastguard Worker rank = x.dim() 191*523fa7a6SAndroid Build Coastguard Worker x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) 192*523fa7a6SAndroid Build Coastguard Worker x = torch.transpose(x, 1, 2) 193*523fa7a6SAndroid Build Coastguard Worker res = self.conv(x) 194*523fa7a6SAndroid Build Coastguard Worker res = torch.transpose(res, 1, 2) 195*523fa7a6SAndroid Build Coastguard Worker res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) 196*523fa7a6SAndroid Build Coastguard Worker return res 197*523fa7a6SAndroid Build Coastguard Worker 198*523fa7a6SAndroid Build Coastguard 
Worker def replace_linear(module: torch.nn.Module): 199*523fa7a6SAndroid Build Coastguard Worker attr_strs = dir(module) 200*523fa7a6SAndroid Build Coastguard Worker if isinstance(module, torch.nn.ModuleList): 201*523fa7a6SAndroid Build Coastguard Worker attr_strs += [str(i) for i in range(len(module))] 202*523fa7a6SAndroid Build Coastguard Worker 203*523fa7a6SAndroid Build Coastguard Worker for attr_str in attr_strs: 204*523fa7a6SAndroid Build Coastguard Worker target_attr = getattr(module, attr_str) 205*523fa7a6SAndroid Build Coastguard Worker if isinstance(target_attr, torch.nn.Linear): 206*523fa7a6SAndroid Build Coastguard Worker setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias)) 207*523fa7a6SAndroid Build Coastguard Worker 208*523fa7a6SAndroid Build Coastguard Worker for _, sub_module in module.named_children(): 209*523fa7a6SAndroid Build Coastguard Worker sub_module = replace_linear(sub_module) 210*523fa7a6SAndroid Build Coastguard Worker return module 211*523fa7a6SAndroid Build Coastguard Worker 212*523fa7a6SAndroid Build Coastguard Worker return replace_linear(module) 213*523fa7a6SAndroid Build Coastguard Worker 214*523fa7a6SAndroid Build Coastguard Worker 215*523fa7a6SAndroid Build Coastguard Workerdef update_spill_fill_size( 216*523fa7a6SAndroid Build Coastguard Worker exported_program: ExportedProgram | List[LoweredBackendModule], 217*523fa7a6SAndroid Build Coastguard Worker): 218*523fa7a6SAndroid Build Coastguard Worker # check if user specifies to use multi_contexts 219*523fa7a6SAndroid Build Coastguard Worker # this is a generic approach in case there exists multiple backends 220*523fa7a6SAndroid Build Coastguard Worker def get_program_info(program): 221*523fa7a6SAndroid Build Coastguard Worker def process_exported_program(prog): 222*523fa7a6SAndroid Build Coastguard Worker max_sf_buf_size, module_map = 0, {} 223*523fa7a6SAndroid Build Coastguard Worker for _, m in prog.graph_module._modules.items(): 224*523fa7a6SAndroid Build 
Coastguard Worker # currently only 1 compile spec is expected in each partition 225*523fa7a6SAndroid Build Coastguard Worker options = flatbuffer_to_option(m.compile_specs[0].value) 226*523fa7a6SAndroid Build Coastguard Worker if ( 227*523fa7a6SAndroid Build Coastguard Worker options.backend_options.backend_type 228*523fa7a6SAndroid Build Coastguard Worker == QnnExecuTorchBackendType.kHtpBackend 229*523fa7a6SAndroid Build Coastguard Worker and options.backend_options.htp_options.use_multi_contexts 230*523fa7a6SAndroid Build Coastguard Worker ): 231*523fa7a6SAndroid Build Coastguard Worker qnn_mgr = PyQnnManagerAdaptor.QnnManager( 232*523fa7a6SAndroid Build Coastguard Worker m.compile_specs[0].value, m.processed_bytes 233*523fa7a6SAndroid Build Coastguard Worker ) 234*523fa7a6SAndroid Build Coastguard Worker assert qnn_mgr.Init().value == 0, "failed to load context binary" 235*523fa7a6SAndroid Build Coastguard Worker max_sf_buf_size = max( 236*523fa7a6SAndroid Build Coastguard Worker max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize() 237*523fa7a6SAndroid Build Coastguard Worker ) 238*523fa7a6SAndroid Build Coastguard Worker module_map[m] = options 239*523fa7a6SAndroid Build Coastguard Worker qnn_mgr.Destroy() 240*523fa7a6SAndroid Build Coastguard Worker return max_sf_buf_size, module_map 241*523fa7a6SAndroid Build Coastguard Worker 242*523fa7a6SAndroid Build Coastguard Worker def process_lowered_module(module): 243*523fa7a6SAndroid Build Coastguard Worker qnn_mgr = PyQnnManagerAdaptor.QnnManager( 244*523fa7a6SAndroid Build Coastguard Worker module.compile_specs[0].value, module.processed_bytes 245*523fa7a6SAndroid Build Coastguard Worker ) 246*523fa7a6SAndroid Build Coastguard Worker assert qnn_mgr.Init().value == 0, "failed to load context binary" 247*523fa7a6SAndroid Build Coastguard Worker spill_fill_size = qnn_mgr.GetSpillFillBufferSize() 248*523fa7a6SAndroid Build Coastguard Worker qnn_mgr.Destroy() 249*523fa7a6SAndroid Build Coastguard Worker return 
spill_fill_size, { 250*523fa7a6SAndroid Build Coastguard Worker module: flatbuffer_to_option(module.compile_specs[0].value) 251*523fa7a6SAndroid Build Coastguard Worker } 252*523fa7a6SAndroid Build Coastguard Worker 253*523fa7a6SAndroid Build Coastguard Worker dispatch = { 254*523fa7a6SAndroid Build Coastguard Worker ExportedProgram: process_exported_program, 255*523fa7a6SAndroid Build Coastguard Worker LoweredBackendModule: process_lowered_module, 256*523fa7a6SAndroid Build Coastguard Worker } 257*523fa7a6SAndroid Build Coastguard Worker return dispatch[type(program)](program) 258*523fa7a6SAndroid Build Coastguard Worker 259*523fa7a6SAndroid Build Coastguard Worker def update_program(max_sf_buf_size, module_map): 260*523fa7a6SAndroid Build Coastguard Worker def set_spec(module, options): 261*523fa7a6SAndroid Build Coastguard Worker spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(options)) 262*523fa7a6SAndroid Build Coastguard Worker if isinstance(module, ExportedProgram): 263*523fa7a6SAndroid Build Coastguard Worker module.compile_specs[0] = spec 264*523fa7a6SAndroid Build Coastguard Worker else: 265*523fa7a6SAndroid Build Coastguard Worker module._compile_specs[0] = spec 266*523fa7a6SAndroid Build Coastguard Worker 267*523fa7a6SAndroid Build Coastguard Worker for module, options in module_map.items(): 268*523fa7a6SAndroid Build Coastguard Worker options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size 269*523fa7a6SAndroid Build Coastguard Worker set_spec(module, options) 270*523fa7a6SAndroid Build Coastguard Worker 271*523fa7a6SAndroid Build Coastguard Worker if isinstance(exported_program, list): 272*523fa7a6SAndroid Build Coastguard Worker max_sf_size, modules_map = 0, {} 273*523fa7a6SAndroid Build Coastguard Worker for prog in exported_program: 274*523fa7a6SAndroid Build Coastguard Worker max_sf_buf_size, module_map = get_program_info(prog) 275*523fa7a6SAndroid Build Coastguard Worker max_sf_size = max(max_sf_size, max_sf_buf_size) 
276*523fa7a6SAndroid Build Coastguard Worker modules_map.update(module_map) 277*523fa7a6SAndroid Build Coastguard Worker update_program(max_sf_size, modules_map) 278*523fa7a6SAndroid Build Coastguard Worker else: 279*523fa7a6SAndroid Build Coastguard Worker update_program(*get_program_info(exported_program)) 280*523fa7a6SAndroid Build Coastguard Worker 281*523fa7a6SAndroid Build Coastguard Worker 282*523fa7a6SAndroid Build Coastguard Workerdef get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: 283*523fa7a6SAndroid Build Coastguard Worker source_decompositions = torch_core_aten_decompositions() 284*523fa7a6SAndroid Build Coastguard Worker # The below super ops are supported by QNN 285*523fa7a6SAndroid Build Coastguard Worker remove_decompositions = [ 286*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.pixel_shuffle.default, 287*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.pixel_unshuffle.default, 288*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.hardsigmoid.default, 289*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.hardswish.default, 290*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten._safe_softmax.default, 291*523fa7a6SAndroid Build Coastguard Worker ] 292*523fa7a6SAndroid Build Coastguard Worker 293*523fa7a6SAndroid Build Coastguard Worker for key in remove_decompositions: 294*523fa7a6SAndroid Build Coastguard Worker source_decompositions.pop(key) 295*523fa7a6SAndroid Build Coastguard Worker 296*523fa7a6SAndroid Build Coastguard Worker return source_decompositions 297*523fa7a6SAndroid Build Coastguard Worker 298*523fa7a6SAndroid Build Coastguard Worker 299*523fa7a6SAndroid Build Coastguard Workerdef _transform( 300*523fa7a6SAndroid Build Coastguard Worker edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset() 301*523fa7a6SAndroid Build Coastguard Worker) -> ExportedProgram: 302*523fa7a6SAndroid Build Coastguard Worker # currently ExirExportedProgram.transform does not accept 
def _transform(
    edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
) -> ExportedProgram:
    """Run the QNN lowering pass pipeline over `edge_program` in place.

    The passes mutate `edge_program.graph_module` one at a time; pass order
    matters (e.g. recompose passes run before pattern-based conversions, and
    FoldQDQ runs after the annotate passes). Returns the same program after
    refreshing its graph signature and re-validating it.
    """
    # currently ExirExportedProgram.transform does not accept
    # changes of input number which was caused by FoldQDQ
    # apply passes one by one here to avoid IR capture failure
    graph_module = edge_program.graph_module
    RemoveRedundancy()(graph_module)
    RecomposePixelUnshuffle()(graph_module)
    RecomposeRmsNorm()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertPReLU(edge_program)(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    # second flag toggles the advanced-requant behavior off when the user
    # passed QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    AnnotateQuantAttrs(
        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    )(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    # this pass is not necessary for network without layout-sensitive ops
    # enable defaultly will introduce overhead from extra view_copy nodes
    if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
        ExpandBroadcastTensorShape()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    ReplaceIndexPutInput(edge_program)(graph_module)

    # Since QDQ nodes are stripped, update graph signature again to validate program
    edge_program._graph_signature = _get_updated_graph_signature(
        edge_program.graph_signature,
        edge_program.graph_module,
    )
    edge_program._validate()
    return edge_program


def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
    """Export `module`, decompose it, lower to edge IR and run the QNN passes.

    Returns the edge-dialect ExirExportedProgram after `_transform` has been
    applied to its exported program in place.
    """
    ep = torch.export.export(module, inputs)
    # drop decompositions for ops QNN supports natively (see get_decomp_table)
    decomposed_ep = ep.run_decompositions(get_decomp_table())
    # We choose call_operator by target in ConvertBinaryOpsWithScalar
    # because it is the same source_fn_stack for MultiheadAttention
    # TODO: Should modify the scalar op in the op builder instead of
    # using transformation
    core_ep = ExirExportedProgram(decomposed_ep, False)
    core_ep.transform(ConvertBinaryOpsWithScalar())
    edge_ep = core_ep.to_edge(qnn_edge_config())
    _transform(edge_ep.exported_program, custom_pass_config)
    return edge_ep
def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn):
    """Group partitioner-proposed node sets into call_module submodules.

    Each partition proposed by `ptn` is fused into a submodule named
    `{subgm_tag}_{i}`; `subgm_cb(subgm, name)` may transform the fused
    submodule before it is inserted. Returns the recompiled graph module.
    """
    # fuser utilities are imported lazily to keep module import light
    from torch.fx.passes.utils.fuser_utils import (
        erase_nodes,
        fuse_as_graphmodule,
        insert_subgm,
        legalize_graph,
        topo_sort,
    )

    partitions = ptn.propose_partitions()
    # insert meta for each partition group
    for i, partition in enumerate(partitions):
        for node in partition.nodes:
            node.meta[subgm_tag] = i

    for i in range(len(partitions)):
        # find nodes with same group id in current graph
        # (untagged nodes default to "" which never equals an int id)
        node_list = [
            node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i
        ]
        # fuse group nodes into submodule
        sorted_nodes = topo_sort(node_list)
        submodule_name = f"{subgm_tag}_{i}"
        subgm, orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, sorted_nodes, submodule_name
        )
        # insert submodule & trim group nodes
        gm = insert_subgm(
            gm,
            subgm_cb(subgm, submodule_name),
            orig_inputs,
            orig_outputs,
        )
        erase_nodes(gm, sorted_nodes)
        # restore a valid topological order after the surgery above
        legalize_graph(gm)

    gm.recompile()
    return gm
def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn):
    """Lower every tagged submodule of `gm` with partitioner `ptn`.

    For each call_module node whose name contains `subgm_tag`, the submodule
    is re-captured for QNN, lowered via to_backend, and spliced back into the
    graph. Returns the recompiled graph module and the list of lowered
    exported programs (for user debugging).
    """
    from executorch.exir.backend.backend_api import to_backend

    # return lowered program for user to debug
    exported_progs = []
    # partition each submodule which went through convert_pt2e
    for node in gm.graph.nodes:
        if node.op == "call_module" and subgm_tag in node.name:
            # obtain sample inputs through meta
            # (assumes every arg carries a FakeTensor in meta["val"] — holds
            # for graphs produced by export; verify for hand-built graphs)
            subgm_input = [
                torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype)
                for arg in node.args
            ]
            # program meets QNN backend requirement
            sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input))
            # start lowering with given partitioner
            exported_progs.append(to_backend(sub_prog.exported_program, ptn))
            # replace submodule with lowered module
            gm.set_submodule(
                node.name,
                exported_progs[-1].graph_module,
            )
            # if node has multiple outputs, getitems will be default generated
            # here we add the missing getitem so downstream users read output 0
            if all(n.target != operator.getitem for n in node.users):
                with gm.graph.inserting_after(node):
                    getitem_node = gm.graph.call_function(
                        operator.getitem,
                        (node, 0),
                    )
                    getitem_node.meta = node.meta
                    # keep the getitem itself pointing at `node`; rewire all
                    # other users to read through the new getitem
                    node.replace_all_uses_with(
                        replace_with=getitem_node,
                        delete_user_cb=lambda user: user.target != operator.getitem,
                    )

    gm.recompile()
    return gm, exported_progs
def skip_annotation(
    nn_module: torch.nn.Module,
    quantizer,
    partitioner,
    sample_input: Tuple[torch.Tensor, ...],
    calibration_cb: Callable[[torch.fx.GraphModule], None],
    fp_node_id_set: set | None = None,
    fp_node_op_set: set | None = None,
    fallback_to_cpu: bool = True,
):
    r"""
    Exclude specific operators from quantizer annotation.
    Skipped operators stay on CPU by default; set 'fallback_to_cpu'
    to False to try delegating them with FP16 precision instead.

    e.g.: consider following graph:
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                  conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \_______   _______/
                          add_1
             (torch.ops.aten.add.default)
                            |
                          output

    If user wants to skip convolution op by names with
    'fp_node_id_set' = {"conv2d_1"}
    "bias_1 / weight_1 / input_1 / input_2 / conv2d_1"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   input_2
      | (placeholder) |          |
       \      |      /           |
        \     |     /            |
         \    |    /             |
           conv2d_1              |
               \                /
                \              /
                 \            /
               lowered_module_1
              (QNN fixed precision)
                       |
                     output

    If user wants to skip convolution op by target with
    'fp_node_op_set' = {torch.ops.aten.conv2d.default}
    "bias_1 / weight_1 / input_1 / conv2d_1,
     bias_2 / weight_2 / input_2 / conv2d_2"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                  conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \__             __/
                  lowered_module_1
                 (QNN fixed precision)
                         |
                       output

    If user wants to delegate the skipped conv2d from above graph
    with 'fallback_to_cpu' = False:

    [Generated graph]
       input_1         input_2
    (placeholder)   (placeholder)
          |               |
           \             /
          lowered_module_2
         (QNN fp16 precision)
                  |
                  |
          lowered_module_1
         (QNN fixed precision)
                  |
                output

    Args:
        nn_module (torch.nn.Module): The module to be lowered.
        quantizer (QnnQuantizer): Instance of QnnQuantizer.
        partitioner (QnnPartitioner): Instance of QnnPartitioner.
        sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting.
        calibration_cb (callable): Callback function for user-defined calibration.
        fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision.
        fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision.
        fallback_to_cpu (bool): Whether to lower skipped nodes to fp16 or not.

    Returns:
        exported_programs: List of programs lowered to QnnBackend (quantized graphs only).
    """
    from executorch.backends.qualcomm.serialization.qc_schema import (
        QnnExecuTorchHtpPrecision,
    )
    from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
        flatbuffer_to_option,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

    def prepare_subgm(subgm, subgm_name):
        # prepare current submodule for quantization annotation
        subgm_prepared = prepare_pt2e(subgm, quantizer)
        # overwrite this attribute or name will be set to "GraphModule";
        # we could not identify each submodule if action is not performed
        subgm_prepared.__class__.__name__ = subgm_name
        return subgm_prepared

    # normalize optional skip sets (None means "skip nothing")
    fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
    fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
    graph_module = torch.export.export(nn_module, sample_input).module()
    # define node support type
CapabilityBasedPartitioner( 558*523fa7a6SAndroid Build Coastguard Worker graph_module, 559*523fa7a6SAndroid Build Coastguard Worker _AnnotationSkipper(fp_node_id_set, fp_node_op_set), 560*523fa7a6SAndroid Build Coastguard Worker allows_single_node_partition=True, 561*523fa7a6SAndroid Build Coastguard Worker ) 562*523fa7a6SAndroid Build Coastguard Worker subgm_tag = "annotated_group" 563*523fa7a6SAndroid Build Coastguard Worker graph_module = _partition_graph_into_submodules( 564*523fa7a6SAndroid Build Coastguard Worker gm=graph_module, 565*523fa7a6SAndroid Build Coastguard Worker subgm_tag=subgm_tag, 566*523fa7a6SAndroid Build Coastguard Worker subgm_cb=prepare_subgm, 567*523fa7a6SAndroid Build Coastguard Worker ptn=capability_partitioner, 568*523fa7a6SAndroid Build Coastguard Worker ) 569*523fa7a6SAndroid Build Coastguard Worker # perform calibration 570*523fa7a6SAndroid Build Coastguard Worker calibration_cb(graph_module) 571*523fa7a6SAndroid Build Coastguard Worker # convert sub modules which went through prepare_pt2e 572*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 573*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_module": 574*523fa7a6SAndroid Build Coastguard Worker graph_module.set_submodule( 575*523fa7a6SAndroid Build Coastguard Worker node.name, convert_pt2e(graph_module.get_submodule(node.name)) 576*523fa7a6SAndroid Build Coastguard Worker ) 577*523fa7a6SAndroid Build Coastguard Worker # canonicalize graph for lowering again 578*523fa7a6SAndroid Build Coastguard Worker graph_module, exported_progs = _canonicalize_graph_with_lowered_module( 579*523fa7a6SAndroid Build Coastguard Worker gm=graph_module, 580*523fa7a6SAndroid Build Coastguard Worker subgm_tag=subgm_tag, 581*523fa7a6SAndroid Build Coastguard Worker ptn=partitioner, 582*523fa7a6SAndroid Build Coastguard Worker ) 583*523fa7a6SAndroid Build Coastguard Worker 584*523fa7a6SAndroid Build Coastguard Worker if not fallback_to_cpu: 585*523fa7a6SAndroid 
Build Coastguard Worker try: 586*523fa7a6SAndroid Build Coastguard Worker from executorch.exir.backend.partitioner import DelegationSpec 587*523fa7a6SAndroid Build Coastguard Worker 588*523fa7a6SAndroid Build Coastguard Worker # change HTP compiler spec for hardware to enable fp16 589*523fa7a6SAndroid Build Coastguard Worker qnn_option = generate_qnn_executorch_option( 590*523fa7a6SAndroid Build Coastguard Worker partitioner.compiler_specs_snapshot 591*523fa7a6SAndroid Build Coastguard Worker ) 592*523fa7a6SAndroid Build Coastguard Worker compile_option = flatbuffer_to_option(qnn_option) 593*523fa7a6SAndroid Build Coastguard Worker htp_options = compile_option.backend_options.htp_options 594*523fa7a6SAndroid Build Coastguard Worker htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16 595*523fa7a6SAndroid Build Coastguard Worker partitioner.delegation_spec = DelegationSpec( 596*523fa7a6SAndroid Build Coastguard Worker "QnnBackend", 597*523fa7a6SAndroid Build Coastguard Worker [ 598*523fa7a6SAndroid Build Coastguard Worker CompileSpec( 599*523fa7a6SAndroid Build Coastguard Worker QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option) 600*523fa7a6SAndroid Build Coastguard Worker ) 601*523fa7a6SAndroid Build Coastguard Worker ], 602*523fa7a6SAndroid Build Coastguard Worker ) 603*523fa7a6SAndroid Build Coastguard Worker except: 604*523fa7a6SAndroid Build Coastguard Worker print( 605*523fa7a6SAndroid Build Coastguard Worker "Failed to change HTP compiler spec with 'use_fp16' as True," 606*523fa7a6SAndroid Build Coastguard Worker " skipped operators will fallback to cpu," 607*523fa7a6SAndroid Build Coastguard Worker ) 608*523fa7a6SAndroid Build Coastguard Worker return graph_module, exported_progs 609*523fa7a6SAndroid Build Coastguard Worker 610*523fa7a6SAndroid Build Coastguard Worker # try lowering skipped operator into fp16 611*523fa7a6SAndroid Build Coastguard Worker capability_partitioner = CapabilityBasedPartitioner( 612*523fa7a6SAndroid Build Coastguard 
Worker graph_module, 613*523fa7a6SAndroid Build Coastguard Worker _AnnotationSkipper(skip_annotated_submodule=True), 614*523fa7a6SAndroid Build Coastguard Worker allows_single_node_partition=True, 615*523fa7a6SAndroid Build Coastguard Worker ) 616*523fa7a6SAndroid Build Coastguard Worker subgm_tag = "skipped_group" 617*523fa7a6SAndroid Build Coastguard Worker graph_module = _partition_graph_into_submodules( 618*523fa7a6SAndroid Build Coastguard Worker gm=graph_module, 619*523fa7a6SAndroid Build Coastguard Worker subgm_tag=subgm_tag, 620*523fa7a6SAndroid Build Coastguard Worker subgm_cb=lambda subgm, _: subgm, 621*523fa7a6SAndroid Build Coastguard Worker ptn=capability_partitioner, 622*523fa7a6SAndroid Build Coastguard Worker ) 623*523fa7a6SAndroid Build Coastguard Worker graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module( 624*523fa7a6SAndroid Build Coastguard Worker gm=graph_module, 625*523fa7a6SAndroid Build Coastguard Worker subgm_tag=subgm_tag, 626*523fa7a6SAndroid Build Coastguard Worker ptn=partitioner, 627*523fa7a6SAndroid Build Coastguard Worker ) 628*523fa7a6SAndroid Build Coastguard Worker exported_progs.extend(exported_progs_fp) 629*523fa7a6SAndroid Build Coastguard Worker 630*523fa7a6SAndroid Build Coastguard Worker return graph_module, exported_progs 631*523fa7a6SAndroid Build Coastguard Worker 632*523fa7a6SAndroid Build Coastguard Worker 633*523fa7a6SAndroid Build Coastguard Workerdef from_context_binary( # noqa: C901 634*523fa7a6SAndroid Build Coastguard Worker ctx_path: str | bytes, 635*523fa7a6SAndroid Build Coastguard Worker op_name: str, 636*523fa7a6SAndroid Build Coastguard Worker soc_model: QcomChipset = QcomChipset.SM8650, 637*523fa7a6SAndroid Build Coastguard Worker custom_info: Dict = None, 638*523fa7a6SAndroid Build Coastguard Worker): 639*523fa7a6SAndroid Build Coastguard Worker from pathlib import Path 640*523fa7a6SAndroid Build Coastguard Worker 641*523fa7a6SAndroid Build Coastguard Worker def 
    def implement_op(custom_op, op_name, outputs):
        # register a meta-device kernel so torch.export can trace the custom
        # op; output shapes/dtypes mirror the context binary's graph outputs
        @torch.library.impl(
            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
        )
        def op_impl(inputs: List[torch.Tensor]):
            return tuple(
                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
                for v in outputs.values()
            )

    def build_graph(inputs, outputs):
        # custom op declaration
        inputs_str = "Tensor[] inputs"
        func_proto = f"{op_name}({inputs_str}) -> Any"
        custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
        custom_op.define(func_proto)
        # custom op implementation
        implement_op(custom_op, op_name, outputs)

        # model architecture mimicking context binary
        class Model(torch.nn.Module):
            def forward(self, *inputs):
                return getattr(
                    getattr(torch.ops, OpContextLoader.namespace), op_name
                ).default(inputs)

        model = Model()
        prog = torch.export.export(model, tuple(inputs.values()))
        # bookkeeping for variables' life cycle
        return {
            "custom_op": custom_op,
            "custom_module": model,
            "exported_program": prog,
        }

    def build_tensor(tensors, dtype_map):
        # map QNN tensor descriptors to zero-filled torch tensors,
        # keyed by tensor name (insertion order preserved)
        ret = OrderedDict()
        for t in tensors:
            dtype = t.GetDataType()
            dtype_torch = dtype_map.get(dtype, None)
            assert dtype_torch is not None, f"unknown qnn data type {dtype}"
            ret[t.GetName()] = torch.zeros(tuple(t.GetDims()), dtype=dtype_torch)

        return ret

    def preprocess_binary(ctx_bin, compiler_specs):
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
        )
        return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin))

    # dummy compiler spec would be fine, since we're not compiling
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
        is_from_context_binary=True,
    )

    # str means a path on disk; bytes means an already-loaded binary
    ctx_bin = (
        ctx_path
        if not isinstance(ctx_path, str)
        else preprocess_binary(Path(f"{ctx_path}").read_bytes(), compiler_specs)
    )

    # reverse lookup: torch dtype <- QNN dtype (first mapping wins)
    dtype_map = {}
    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
        for k, v in type_map.items():
            dtype_map.setdefault(v, k)

    if custom_info is not None:
        # since some context binaries might fail to open on host
        # if they are compiled with special flags:
        # e.g. weight sharing
        # use custom information here instead
        inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
        outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
        graph_name = custom_info["graph_name"]
    else:
        # get context-binary io tensor info through qnn manager
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
            ctx_bin,
        )
        assert qnn_mgr.Init().value == 0, "failed to load context binary"
        # assume we only have one graph in current context
        graph_name = qnn_mgr.GetGraphNames()[0]
        qnn_mgr.AllocateTensor(graph_name)
        inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map)
        outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map)
        qnn_mgr.Destroy()

    # generate graph specific for loading context
    bundle_prog = build_graph(inputs, outputs)
    bundle_prog.update({"inputs": inputs, "outputs": outputs})
    edge_prog_mgr = to_edge(
        programs={graph_name: bundle_prog["exported_program"]},
        # do not alter name for custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # update meta with context binary
    for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
        if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
            n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
            break

    bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
        QnnPartitioner(compiler_specs)
    )
    return bundle_prog


def draw_graph(title, path, graph_module: torch.fx.GraphModule):
    """Render `graph_module` with FxGraphDrawer and write `{path}/{title}.svg`."""
    graph = passes.graph_drawer.FxGraphDrawer(graph_module, title)
    with open(f"{path}/{title}.svg", "wb") as f:
        f.write(graph.get_dot_graph().create_svg())


def generate_multi_graph_program(
    compiler_specs: List[CompileSpec],
    processed_bytes: List[bytes],
    backend_config: ExecutorchBackendConfig = None,
) -> ExecutorchProgramManager:
    # compile multiple graphs in qcir into single context binary
    graph_inputs, graph_outputs = {}, {}
    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
        generate_qnn_executorch_option(compiler_specs), processed_bytes
    )
    assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
    binary_info = bytes(qnn_mgr.Compile())
    assert len(binary_info) != 0, "failed to generate QNN context binary"
    # collect per-graph IO signatures before tearing the manager down
    graph_names = qnn_mgr.GetGraphNames()
    for graph_name in graph_names:
        graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name)
        graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
    qnn_mgr.Destroy()

    # build custom ops with different graph signatures
    compiler_options = flatbuffer_to_option(compiler_specs[0].value)
    bundle_progs = [
        from_context_binary(
            ctx_path=binary_info,
            op_name=f"loader_{graph_name}",
            soc_model=compiler_options.soc_info.soc_model,
            custom_info={
                "graph_inputs": graph_inputs[graph_name],
                "graph_outputs": graph_outputs[graph_name],
                "graph_name": graph_name,
            },
        )
        for graph_name in graph_names
    ]
    # leverage ExecutorchProgramManager for generating pte with multi-methods
    edge_prog_mgr = to_edge(
        programs={
            graph_name: bundle_prog["exported_program"]
            for graph_name, bundle_prog in zip(graph_names, bundle_progs)
        },
        # do not alter name for custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # restore meta lost in generating EdgeProgramManager
    # (matches the custom loader node by graph name substring)
    for graph_name in graph_names:
        for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
            if graph_name in n.name:
                n.meta[OpContextLoader.meta_ctx_bin] = binary_info
                break

    return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch(
        config=backend_config or ExecutorchBackendConfig()
    )

def generate_htp_compiler_spec(
    use_fp16: bool,
    use_dlbc: bool = False,
    use_multi_contexts: bool = False,
) -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN HTP

    Args:
        use_fp16: If true, the model is compiled to QNN HTP fp16 runtime.
            Note that not all SoC support QNN HTP fp16. Only premium tier SoC
            like Snapdragon 8 Gen 1 or newer can support HTP fp16.
        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
            compressed, such that the processing bandwidth can be lowered.
        use_multi_contexts: When multiple contexts are generated inside the same
            pte, it is possible to reserve a single spill-fill allocation that
            could be re-used across all the splits.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN HTP
        (with ``htp_options`` populated).
    """
    htp_options = QnnExecuTorchHtpBackendOptions()
    htp_options.precision = (
        QnnExecuTorchHtpPrecision.kHtpFp16
        if use_fp16
        else QnnExecuTorchHtpPrecision.kHtpQuantized
    )
    # This actually is not an option which can affect the compiled blob.
    # But we don't have other place to pass this option at execution stage.
    # TODO: enable voting mechanism in runtime and make this as an option
    htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
    htp_options.use_multi_contexts = use_multi_contexts
    htp_options.use_dlbc = use_dlbc
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
        htp_options=htp_options,
    )


def generate_qnn_executorch_compiler_spec(
    soc_model: QcomChipset,
    backend_options: QnnExecuTorchBackendOptions,
    debug: bool = False,
    saver: bool = False,
    online_prepare: bool = False,
Worker dump_intermediate_outputs: bool = False, 860*523fa7a6SAndroid Build Coastguard Worker profile: bool = False, 861*523fa7a6SAndroid Build Coastguard Worker optrace: bool = False, 862*523fa7a6SAndroid Build Coastguard Worker shared_buffer: bool = False, 863*523fa7a6SAndroid Build Coastguard Worker is_from_context_binary: bool = False, 864*523fa7a6SAndroid Build Coastguard Worker multiple_graphs: bool = False, 865*523fa7a6SAndroid Build Coastguard Worker graph_name: str = "forward", 866*523fa7a6SAndroid Build Coastguard Worker) -> List[CompileSpec]: 867*523fa7a6SAndroid Build Coastguard Worker """ 868*523fa7a6SAndroid Build Coastguard Worker Helper function generating compiler specs for Qualcomm AI Engine Direct 869*523fa7a6SAndroid Build Coastguard Worker 870*523fa7a6SAndroid Build Coastguard Worker Args: 871*523fa7a6SAndroid Build Coastguard Worker soc_model: The SoC you plan to run the compiled model. Please check 872*523fa7a6SAndroid Build Coastguard Worker QcomChipset for supported SoC. 873*523fa7a6SAndroid Build Coastguard Worker SM8450 (Snapdragon 8 Gen 1) 874*523fa7a6SAndroid Build Coastguard Worker SM8475(Snapdragon 8 Gen 1+) 875*523fa7a6SAndroid Build Coastguard Worker SM8550(Snapdragon 8 Gen 2) 876*523fa7a6SAndroid Build Coastguard Worker SM8650(Snapdragon 8 Gen 3) 877*523fa7a6SAndroid Build Coastguard Worker backend_options: Options required by different backends. 878*523fa7a6SAndroid Build Coastguard Worker debug: Enable verbose logging. Disclaimer: this option must change in 879*523fa7a6SAndroid Build Coastguard Worker the near future. 880*523fa7a6SAndroid Build Coastguard Worker online_prepare: Compose QNN graph on device if set to True 881*523fa7a6SAndroid Build Coastguard Worker saver: Instead of compiling the model, run QNN Saver. Please check 882*523fa7a6SAndroid Build Coastguard Worker documents of Qualcomm AI Engine Direct SDK. This feature is usually 883*523fa7a6SAndroid Build Coastguard Worker for debugging purpose. 
884*523fa7a6SAndroid Build Coastguard Worker dump_intermediate_outputs: If tensor dump is enabled, all intermediate tensors output will be dumped. 885*523fa7a6SAndroid Build Coastguard Worker This option exists for debugging accuracy issues 886*523fa7a6SAndroid Build Coastguard Worker profile: Enable profile the performance of per operator. 887*523fa7a6SAndroid Build Coastguard Worker Note that for now only support kProfileDetailed to 888*523fa7a6SAndroid Build Coastguard Worker profile the performance of each operator with cycle unit. 889*523fa7a6SAndroid Build Coastguard Worker shared_buffer: Enables usage of shared buffer between application 890*523fa7a6SAndroid Build Coastguard Worker and backend for graph I/O. 891*523fa7a6SAndroid Build Coastguard Worker is_from_context_binary: True if current graph comes from pre-built context binary. 892*523fa7a6SAndroid Build Coastguard Worker multiple_graphs: True if multiple methods are expected to have in single .pte file. 893*523fa7a6SAndroid Build Coastguard Worker Please see test cases for post-processing example. 894*523fa7a6SAndroid Build Coastguard Worker graph_name: Assign unique graph name if 'multiple_graphs' is used. 895*523fa7a6SAndroid Build Coastguard Worker 896*523fa7a6SAndroid Build Coastguard Worker Returns: 897*523fa7a6SAndroid Build Coastguard Worker List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct. 898*523fa7a6SAndroid Build Coastguard Worker 899*523fa7a6SAndroid Build Coastguard Worker Raises: 900*523fa7a6SAndroid Build Coastguard Worker ValueError: The value QcomChipset is currently not supported. 901*523fa7a6SAndroid Build Coastguard Worker ValueError: Confliction between compiler specs. 
902*523fa7a6SAndroid Build Coastguard Worker """ 903*523fa7a6SAndroid Build Coastguard Worker _supported_soc_models = {soc_model.value for soc_model in QcomChipset} 904*523fa7a6SAndroid Build Coastguard Worker if soc_model not in _supported_soc_models: 905*523fa7a6SAndroid Build Coastguard Worker raise ValueError(f"unknown SoC model for QNN: {soc_model}") 906*523fa7a6SAndroid Build Coastguard Worker 907*523fa7a6SAndroid Build Coastguard Worker if profile and dump_intermediate_outputs: 908*523fa7a6SAndroid Build Coastguard Worker warnings.warn( 909*523fa7a6SAndroid Build Coastguard Worker "It is not recommended to turn on both profiling and dump_intermediate_outputs the same time" 910*523fa7a6SAndroid Build Coastguard Worker ", because dump_intermediate_outputs will cause performance drop.", 911*523fa7a6SAndroid Build Coastguard Worker stacklevel=1, 912*523fa7a6SAndroid Build Coastguard Worker ) 913*523fa7a6SAndroid Build Coastguard Worker 914*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options = QnnExecuTorchOptions( 915*523fa7a6SAndroid Build Coastguard Worker _soc_info_table[soc_model], backend_options 916*523fa7a6SAndroid Build Coastguard Worker ) 917*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.graph_name = graph_name 918*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.log_level = ( 919*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchLogLevel.kLogLevelDebug 920*523fa7a6SAndroid Build Coastguard Worker if debug 921*523fa7a6SAndroid Build Coastguard Worker else QnnExecuTorchLogLevel.kLogLevelWarn 922*523fa7a6SAndroid Build Coastguard Worker ) 923*523fa7a6SAndroid Build Coastguard Worker 924*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs 925*523fa7a6SAndroid Build Coastguard Worker 926*523fa7a6SAndroid Build Coastguard Worker if saver: 927*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.library_path = "libQnnSaver.so" 
928*523fa7a6SAndroid Build Coastguard Worker 929*523fa7a6SAndroid Build Coastguard Worker if optrace: 930*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace 931*523fa7a6SAndroid Build Coastguard Worker elif profile: 932*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.profile_level = ( 933*523fa7a6SAndroid Build Coastguard Worker QnnExecuTorchProfileLevel.kProfileDetailed 934*523fa7a6SAndroid Build Coastguard Worker ) 935*523fa7a6SAndroid Build Coastguard Worker else: 936*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff 937*523fa7a6SAndroid Build Coastguard Worker 938*523fa7a6SAndroid Build Coastguard Worker if ( 939*523fa7a6SAndroid Build Coastguard Worker online_prepare 940*523fa7a6SAndroid Build Coastguard Worker and backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend 941*523fa7a6SAndroid Build Coastguard Worker and backend_options.htp_options.use_multi_contexts 942*523fa7a6SAndroid Build Coastguard Worker ): 943*523fa7a6SAndroid Build Coastguard Worker raise ValueError( 944*523fa7a6SAndroid Build Coastguard Worker "'use_multi_context' could not function in online prepare mode, " 945*523fa7a6SAndroid Build Coastguard Worker "please set 'online_prepare' to False" 946*523fa7a6SAndroid Build Coastguard Worker ) 947*523fa7a6SAndroid Build Coastguard Worker 948*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.shared_buffer = shared_buffer 949*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.online_prepare = online_prepare 950*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.is_from_context_binary = is_from_context_binary 951*523fa7a6SAndroid Build Coastguard Worker qnn_executorch_options.multiple_graphs = multiple_graphs 952*523fa7a6SAndroid Build Coastguard Worker 953*523fa7a6SAndroid Build Coastguard Worker if multiple_graphs: 954*523fa7a6SAndroid 
Build Coastguard Worker # enable weight sharing mechanism if multiple graphs appear 955*523fa7a6SAndroid Build Coastguard Worker if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend: 956*523fa7a6SAndroid Build Coastguard Worker backend_options.htp_options.use_weight_sharing = True 957*523fa7a6SAndroid Build Coastguard Worker 958*523fa7a6SAndroid Build Coastguard Worker return [ 959*523fa7a6SAndroid Build Coastguard Worker CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options)) 960*523fa7a6SAndroid Build Coastguard Worker ] 961*523fa7a6SAndroid Build Coastguard Worker 962*523fa7a6SAndroid Build Coastguard Worker 963*523fa7a6SAndroid Build Coastguard Workerdef get_soc_to_arch_map(): 964*523fa7a6SAndroid Build Coastguard Worker return { 965*523fa7a6SAndroid Build Coastguard Worker "SSG2115P": HtpArch.V73, 966*523fa7a6SAndroid Build Coastguard Worker "SM8650": HtpArch.V75, 967*523fa7a6SAndroid Build Coastguard Worker "SM8550": HtpArch.V73, 968*523fa7a6SAndroid Build Coastguard Worker "SM8475": HtpArch.V69, 969*523fa7a6SAndroid Build Coastguard Worker "SM8450": HtpArch.V69, 970*523fa7a6SAndroid Build Coastguard Worker "SA8295": HtpArch.V68, 971*523fa7a6SAndroid Build Coastguard Worker } 972*523fa7a6SAndroid Build Coastguard Worker 973*523fa7a6SAndroid Build Coastguard Worker 974*523fa7a6SAndroid Build Coastguard Workerdef get_soc_to_chipset_map(): 975*523fa7a6SAndroid Build Coastguard Worker return { 976*523fa7a6SAndroid Build Coastguard Worker "SSG2115P": QcomChipset.SSG2115P, 977*523fa7a6SAndroid Build Coastguard Worker "SM8650": QcomChipset.SM8650, 978*523fa7a6SAndroid Build Coastguard Worker "SM8550": QcomChipset.SM8550, 979*523fa7a6SAndroid Build Coastguard Worker "SM8475": QcomChipset.SM8475, 980*523fa7a6SAndroid Build Coastguard Worker "SM8450": QcomChipset.SM8450, 981*523fa7a6SAndroid Build Coastguard Worker "SA8295": QcomChipset.SA8295, 982*523fa7a6SAndroid Build Coastguard Worker } 983*523fa7a6SAndroid Build 
def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
    """
    Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess
    """
    for graph_node in gm.graph.nodes:
        io_dtype = get_quant_io_dtype_fn(graph_node)
        if io_dtype:
            graph_node.meta[QCOM_QUANTIZED_IO] = io_dtype