# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import warnings
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Tuple

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

import executorch.exir as exir

import torch
from executorch.backends.qualcomm._passes.annotate_and_quant_scalar import (
    AnnotateAndQuantScalar,
)
from executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed
from executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs
from executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import (
    ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm._passes.convert_bmm_to_matmul import (
    ConvertBmmToMatmul,
)
from executorch.backends.qualcomm._passes.convert_interpolate_with_upsample2d import (
    ConvertInterpolateWithUpsample2D,
)
from executorch.backends.qualcomm._passes.convert_prelu import ConvertPReLU
from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
    ExpandBroadcastTensorShape,
)
from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ
from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.recompose_rms_norm import RecomposeRmsNorm
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy
from executorch.backends.qualcomm._passes.replace_index_put_input import (
    ReplaceIndexPutInput,
)

from executorch.backends.qualcomm.builders.node_visitor import (
    QNN_QUANT_TYPE_MAP,
    QNN_TENSOR_TYPE_MAP,
)
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.qnn_partitioner import (
    generate_qnn_executorch_option,
    QnnPartitioner,
)
from executorch.backends.qualcomm.serialization.qc_schema import (
    _soc_info_table,
    HtpArch,
    QcomChipset,
    QnnExecuTorchBackendOptions,
    QnnExecuTorchBackendType,
    QnnExecuTorchHtpBackendOptions,
    QnnExecuTorchHtpPerformanceMode,
    QnnExecuTorchHtpPrecision,
    QnnExecuTorchLogLevel,
    QnnExecuTorchOptions,
    QnnExecuTorchProfileLevel,
)
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
    flatbuffer_to_option,
    option_to_flatbuffer,
)
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_EXPAND_BROADCAST_SHAPE,
    QCOM_PASS_SKIP_ADVANCED_REQUANT,
    QCOM_QNN_COMPILE_SPEC,
    QCOM_QUANTIZED_IO,
)

from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchProgramManager,
    ExirExportedProgram,
    to_edge,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.capture import ExecutorchBackendConfig
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.program._program import _get_updated_graph_signature
from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.library import Library


class _AnnotationSkipper(OperatorSupportBase):
    """
    Class used to partition out unwanted graph nodes,
    e.g. - nodes to be excluded from quantization annotation
         - nodes that have been grouped together as a submodule

    Attributes
    ----------
    fp_node_id_set : set
        a set containing names of nodes to be left in fp precision
    fp_node_op_set : set
        a set containing targets (aten dialect) of nodes to be left in fp precision
    skip_annotated_submodule : bool
        flag indicating whether annotated submodules should be skipped

    Methods
    -------
    should_delegate(n: torch.fx.Node)
        identify residual nodes that have not been lowered with fixed precision
    should_skip(n: torch.fx.Node)
        identify whether a node should be kept out of fixed precision
    is_node_supported(_, node: torch.fx.Node)
        overridden method for graph partitioning
    """

    def __init__(
        self,
        fp_node_id_set: set = None,
        fp_node_op_set: set = None,
        skip_annotated_submodule: bool = False,
    ):
        self.fp_node_id_set = fp_node_id_set
        self.fp_node_op_set = fp_node_op_set
        self.skip_annotated_submodule = skip_annotated_submodule

    def should_delegate(self, n: torch.fx.Node):
        return n.op == "call_function" and n.target != operator.getitem

    def should_skip(self, n: torch.fx.Node):
        return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        if self.skip_annotated_submodule:
            if node.op == "get_attr":
                return all(self.should_delegate(user) for user in node.users)
            return self.should_delegate(node)

        if any(
            [
                node.op in ("placeholder", "output"),
                self.should_skip(node),
                # check if parameters belong to an operator that falls back
                (
                    node.op == "get_attr"
                    and all(self.should_skip(user) for user in node.users)
                ),
            ]
        ):
            print(f"[QNN Quantizer Annotation]: {node.name} | Skipped")
            return False

        return True


def qnn_capture_config():
    return exir.CaptureConfig(enable_aot=True)


def qnn_edge_config() -> exir.EdgeCompileConfig:
    return exir.EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

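# A minimal usage sketch for the configs above (illustrative; any export
# entry point producing an ExportedProgram works):
#
#   ep = torch.export.export(model, example_inputs)
#   edge = exir.to_edge(ep, compile_config=qnn_edge_config())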

def convert_linear_to_conv2d(module: torch.nn.Module):
    """Replace every torch.nn.Linear in the module tree with an equivalent
    1x1 torch.nn.Conv2d, which the HTP backend tends to handle more
    efficiently."""

    class Conv2D(torch.nn.Module):
        def __init__(self, weight, bias=None):
            super().__init__()
            use_bias = bias is not None
            self.conv = torch.nn.Conv2d(
                in_channels=weight.shape[0],
                out_channels=weight.shape[1],
                kernel_size=1,
                padding=0,
                bias=use_bias,
            )
            # weight and bias are overwritten below, so the channel arguments
            # above only size the initial parameters
            self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1))
            if use_bias:
                self.conv.bias = torch.nn.Parameter(bias)

        def forward(self, x):
            rank = x.dim()
            x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
            x = torch.transpose(x, 1, 2)
            res = self.conv(x)
            res = torch.transpose(res, 1, 2)
            res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
            return res

    def replace_linear(module: torch.nn.Module):
        attr_strs = dir(module)
        if isinstance(module, torch.nn.ModuleList):
            attr_strs += [str(i) for i in range(len(module))]

        for attr_str in attr_strs:
            target_attr = getattr(module, attr_str)
            if isinstance(target_attr, torch.nn.Linear):
                setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias))

        for _, sub_module in module.named_children():
            sub_module = replace_linear(sub_module)
        return module

    return replace_linear(module)

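# Example (a rough sketch): swap linear layers for 1x1 convolutions before
# quantization/export; the two-layer model below is purely illustrative.
#
#   model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 8))
#   model = convert_linear_to_conv2d(model)
#   out = model(torch.randn(2, 4, 16))  # rank-3 (batch, seq, features) input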

def update_spill_fill_size(
    exported_program: ExportedProgram | List[LoweredBackendModule],
):
    """Query the spill-fill buffer size of every lowered HTP context and
    write the maximum back into each module's compile spec, so partitions
    compiled with multi-contexts can share a single buffer."""

    # check whether the user specified multi_contexts;
    # this generic approach covers the case where multiple backends exist
    def get_program_info(program):
        def process_exported_program(prog):
            max_sf_buf_size, module_map = 0, {}
            for _, m in prog.graph_module._modules.items():
                # currently only 1 compile spec is expected in each partition
                options = flatbuffer_to_option(m.compile_specs[0].value)
                if (
                    options.backend_options.backend_type
                    == QnnExecuTorchBackendType.kHtpBackend
                    and options.backend_options.htp_options.use_multi_contexts
                ):
                    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                        m.compile_specs[0].value, m.processed_bytes
                    )
                    assert qnn_mgr.Init().value == 0, "failed to load context binary"
                    max_sf_buf_size = max(
                        max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize()
                    )
                    module_map[m] = options
                    qnn_mgr.Destroy()
            return max_sf_buf_size, module_map

        def process_lowered_module(module):
            qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                module.compile_specs[0].value, module.processed_bytes
            )
            assert qnn_mgr.Init().value == 0, "failed to load context binary"
            spill_fill_size = qnn_mgr.GetSpillFillBufferSize()
            qnn_mgr.Destroy()
            return spill_fill_size, {
                module: flatbuffer_to_option(module.compile_specs[0].value)
            }

        dispatch = {
            ExportedProgram: process_exported_program,
            LoweredBackendModule: process_lowered_module,
        }
        return dispatch[type(program)](program)

    def update_program(max_sf_buf_size, module_map):
        def set_spec(module, options):
            spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(options))
            if isinstance(module, ExportedProgram):
                module.compile_specs[0] = spec
            else:
                module._compile_specs[0] = spec

        for module, options in module_map.items():
            options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
            set_spec(module, options)

    if isinstance(exported_program, list):
        max_sf_size, modules_map = 0, {}
        for prog in exported_program:
            max_sf_buf_size, module_map = get_program_info(prog)
            max_sf_size = max(max_sf_size, max_sf_buf_size)
            modules_map.update(module_map)
        update_program(max_sf_size, modules_map)
    else:
        update_program(*get_program_info(exported_program))

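# Example (a sketch, assuming an edge program already lowered through
# QnnPartitioner with `use_multi_contexts` enabled in the HTP options):
#
#   edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
#   update_spill_fill_size(edge_prog_mgr.exported_program())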

def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
    source_decompositions = torch_core_aten_decompositions()
    # keep the ops below undecomposed, since QNN supports them natively
    remove_decompositions = [
        torch.ops.aten.pixel_shuffle.default,
        torch.ops.aten.pixel_unshuffle.default,
        torch.ops.aten.hardsigmoid.default,
        torch.ops.aten.hardswish.default,
        torch.ops.aten._safe_softmax.default,
    ]

    for key in remove_decompositions:
        source_decompositions.pop(key)

    return source_decompositions

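# Example (sketch): run the trimmed table during export so the ops removed
# above reach the backend undecomposed:
#
#   ep = torch.export.export(model, example_inputs)
#   ep = ep.run_decompositions(get_decomp_table())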

def _transform(
    edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
) -> ExportedProgram:
    # currently ExirExportedProgram.transform does not accept
    # changes to the number of inputs, which FoldQDQ introduces;
    # apply passes one by one here to avoid IR capture failure
    graph_module = edge_program.graph_module
    RemoveRedundancy()(graph_module)
    RecomposePixelUnshuffle()(graph_module)
    RecomposeRmsNorm()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertPReLU(edge_program)(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    AnnotateQuantAttrs(
        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    )(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    # this pass is unnecessary for networks without layout-sensitive ops;
    # enabling it by default would introduce overhead from extra view_copy nodes
    if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
        ExpandBroadcastTensorShape()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    ReplaceIndexPutInput(edge_program)(graph_module)

    # since QDQ nodes are stripped, update the graph signature again to validate the program
    edge_program._graph_signature = _get_updated_graph_signature(
        edge_program.graph_signature,
        edge_program.graph_module,
    )
    edge_program._validate()
    return edge_program


def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
    """Export `module`, run QNN-friendly decompositions and passes, and
    return an edge program ready for partitioning."""
    ep = torch.export.export(module, inputs)
    decomposed_ep = ep.run_decompositions(get_decomp_table())
    # we pick call_operator nodes by target in ConvertBinaryOpsWithScalar
    # because they share the same source_fn_stack for MultiheadAttention
    # TODO: Should modify the scalar op in the op builder instead of
    #       using transformation
    core_ep = ExirExportedProgram(decomposed_ep, False)
    core_ep.transform(ConvertBinaryOpsWithScalar())
    edge_ep = core_ep.to_edge(qnn_edge_config())
    _transform(edge_ep.exported_program, custom_pass_config)
    return edge_ep

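# Example (a condensed sketch of the typical lowering flow; `compiler_specs`
# is assumed to come from generate_qnn_executorch_compiler_spec and
# `to_backend` from executorch.exir.backend.backend_api):
#
#   edge_prog = capture_program(model, example_inputs)
#   edge_prog.exported_program = to_backend(
#       edge_prog.exported_program, QnnPartitioner(compiler_specs)
#   )
#   exec_prog = edge_prog.to_executorch()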

def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn):
    """Fuse the node groups proposed by partitioner `ptn` into submodules,
    tagging each group's nodes with `subgm_tag` and post-processing every
    fused submodule through `subgm_cb`."""
    from torch.fx.passes.utils.fuser_utils import (
        erase_nodes,
        fuse_as_graphmodule,
        insert_subgm,
        legalize_graph,
        topo_sort,
    )

    partitions = ptn.propose_partitions()
    # insert meta for each partition group
    for i, partition in enumerate(partitions):
        for node in partition.nodes:
            node.meta[subgm_tag] = i

    for i in range(len(partitions)):
        # find nodes with the same group id in the current graph
        node_list = [
            node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i
        ]
        # fuse group nodes into a submodule
        sorted_nodes = topo_sort(node_list)
        submodule_name = f"{subgm_tag}_{i}"
        subgm, orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, sorted_nodes, submodule_name
        )
        # insert the submodule & trim group nodes
        gm = insert_subgm(
            gm,
            subgm_cb(subgm, submodule_name),
            orig_inputs,
            orig_outputs,
        )
        erase_nodes(gm, sorted_nodes)
        legalize_graph(gm)

    gm.recompile()
    return gm


def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn):
    from executorch.exir.backend.backend_api import to_backend

    # return the lowered programs so users can debug them
    exported_progs = []
    # partition each submodule which went through convert_pt2e
    for node in gm.graph.nodes:
        if node.op == "call_module" and subgm_tag in node.name:
            # obtain sample inputs through meta
            subgm_input = [
                torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype)
                for arg in node.args
            ]
            # make the program meet QNN backend requirements
            sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input))
            # start lowering with the given partitioner
            exported_progs.append(to_backend(sub_prog.exported_program, ptn))
            # replace the submodule with the lowered module
            gm.set_submodule(
                node.name,
                exported_progs[-1].graph_module,
            )
            # if a node has multiple outputs, getitems are generated by default;
            # make sure one exists for the single-output case as well
            if all(n.target != operator.getitem for n in node.users):
                with gm.graph.inserting_after(node):
                    getitem_node = gm.graph.call_function(
                        operator.getitem,
                        (node, 0),
                    )
                    getitem_node.meta = node.meta
                    node.replace_all_uses_with(
                        replace_with=getitem_node,
                        delete_user_cb=lambda user: user.target != operator.getitem,
                    )

    gm.recompile()
    return gm, exported_progs


def skip_annotation(
    nn_module: torch.nn.Module,
    quantizer,
    partitioner,
    sample_input: Tuple[torch.Tensor, ...],
    calibration_cb: Callable[[torch.fx.GraphModule], None],
    fp_node_id_set: set = None,
    fp_node_op_set: set = None,
    fallback_to_cpu: bool = True,
):
    r"""
    Exclude specific operators from quantizer annotation.
    Skipped operators stay on CPU by default; set 'fallback_to_cpu'
    to False to try delegating them with FP16 precision.

    e.g.: consider the following graph:
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \_______     _______/
                         add_1
             (torch.ops.aten.add.default)
                           |
                         output

    If the user wants to skip the convolution op by name with
    'fp_node_id_set' = {"conv2d_1"}, then
    "bias_1 / weight_1 / input_1 / input_2 / conv2d_1"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   input_2
      | (placeholder) |          |
       \      |      /           |
        \     |     /            |
         \    |    /             |
           conv2d_1              |
              \                 /
               \               /
                \             /
               lowered_module_1
            (QNN fixed precision)
                      |
                    output

    If the user wants to skip convolution ops by target with
    'fp_node_op_set' = {torch.ops.aten.conv2d.default}, then
    "bias_1 / weight_1 / input_1 / conv2d_1,
     bias_2 / weight_2 / input_2 / conv2d_2"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \__               __/
                    lowered_module_1
                 (QNN fixed precision)
                           |
                         output

    If the user wants to delegate the skipped conv2d ops from the graph
    above with 'fallback_to_cpu' = False:

    [Generated graph]
       input_1         input_2
    (placeholder)   (placeholder)
          |               |
          \               /
          lowered_module_2
         (QNN fp16 precision)
                  |
                  |
          lowered_module_1
         (QNN fixed precision)
                  |
                output

    Args:
        nn_module (torch.nn.Module): The module to be lowered.
        quantizer (QnnQuantizer): Instance of QnnQuantizer.
        partitioner (QnnPartitioner): Instance of QnnPartitioner.
        sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting.
        calibration_cb (callable): Callback function for user-defined calibration.
        fp_node_id_set ({str, ...}): Set of node names to be left in fp precision.
        fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision.
        fallback_to_cpu (bool): Whether skipped operators stay on CPU (True)
            or are delegated with fp16 precision (False).

    Returns:
        graph_module: The graph with lowered modules substituted in.
        exported_progs: List of programs lowered to QnnBackend (fp16 graphs
            are included when 'fallback_to_cpu' is False).
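
    Example (a sketch; quantizer, partitioner, and inputs are assumed to be
    constructed elsewhere):
        graph_module, exported_progs = skip_annotation(
            nn_module=model,
            quantizer=quantizer,
            partitioner=partitioner,
            sample_input=example_inputs,
            calibration_cb=lambda m: m(*example_inputs),
            fp_node_id_set={"conv2d_1"},
        )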
    """
    from executorch.backends.qualcomm.serialization.qc_schema import (
        QnnExecuTorchHtpPrecision,
    )
    from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
        flatbuffer_to_option,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

    def prepare_subgm(subgm, subgm_name):
        # prepare the current submodule for quantization annotation
        subgm_prepared = prepare_pt2e(subgm, quantizer)
        # overwrite this attribute, otherwise the name will be set to
        # "GraphModule" and submodules could not be told apart
        subgm_prepared.__class__.__name__ = subgm_name
        return subgm_prepared

    fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
    fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
    graph_module = torch.export.export(nn_module, sample_input).module()
    # define node support type
    capability_partitioner = CapabilityBasedPartitioner(
        graph_module,
        _AnnotationSkipper(fp_node_id_set, fp_node_op_set),
        allows_single_node_partition=True,
    )
    subgm_tag = "annotated_group"
    graph_module = _partition_graph_into_submodules(
        gm=graph_module,
        subgm_tag=subgm_tag,
        subgm_cb=prepare_subgm,
        ptn=capability_partitioner,
    )
    # perform calibration
    calibration_cb(graph_module)
    # convert submodules which went through prepare_pt2e
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            graph_module.set_submodule(
                node.name, convert_pt2e(graph_module.get_submodule(node.name))
            )
    # canonicalize graph for lowering again
    graph_module, exported_progs = _canonicalize_graph_with_lowered_module(
        gm=graph_module,
        subgm_tag=subgm_tag,
        ptn=partitioner,
    )

    if not fallback_to_cpu:
        try:
            from executorch.exir.backend.partitioner import DelegationSpec

            # change the HTP compiler spec to enable fp16 on hardware
            qnn_option = generate_qnn_executorch_option(
                partitioner.compiler_specs_snapshot
            )
            compile_option = flatbuffer_to_option(qnn_option)
            htp_options = compile_option.backend_options.htp_options
            htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16
            partitioner.delegation_spec = DelegationSpec(
                "QnnBackend",
                [
                    CompileSpec(
                        QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option)
                    )
                ],
            )
        except Exception:
            print(
                "Failed to change HTP compiler spec with 'use_fp16' set to True;"
                " skipped operators will fall back to CPU."
            )
            return graph_module, exported_progs

        # try lowering the skipped operators in fp16
        capability_partitioner = CapabilityBasedPartitioner(
            graph_module,
            _AnnotationSkipper(skip_annotated_submodule=True),
            allows_single_node_partition=True,
        )
        subgm_tag = "skipped_group"
        graph_module = _partition_graph_into_submodules(
            gm=graph_module,
            subgm_tag=subgm_tag,
            subgm_cb=lambda subgm, _: subgm,
            ptn=capability_partitioner,
        )
        graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module(
            gm=graph_module,
            subgm_tag=subgm_tag,
            ptn=partitioner,
        )
        exported_progs.extend(exported_progs_fp)

    return graph_module, exported_progs


633*523fa7a6SAndroid Build Coastguard Workerdef from_context_binary(  # noqa: C901
634*523fa7a6SAndroid Build Coastguard Worker    ctx_path: str | bytes,
635*523fa7a6SAndroid Build Coastguard Worker    op_name: str,
636*523fa7a6SAndroid Build Coastguard Worker    soc_model: QcomChipset = QcomChipset.SM8650,
637*523fa7a6SAndroid Build Coastguard Worker    custom_info: Dict = None,
638*523fa7a6SAndroid Build Coastguard Worker):
639*523fa7a6SAndroid Build Coastguard Worker    from pathlib import Path
640*523fa7a6SAndroid Build Coastguard Worker
641*523fa7a6SAndroid Build Coastguard Worker    def implement_op(custom_op, op_name, outputs):
642*523fa7a6SAndroid Build Coastguard Worker        @torch.library.impl(
643*523fa7a6SAndroid Build Coastguard Worker            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
644*523fa7a6SAndroid Build Coastguard Worker        )
645*523fa7a6SAndroid Build Coastguard Worker        def op_impl(inputs: List[torch.Tensor]):
646*523fa7a6SAndroid Build Coastguard Worker            return tuple(
647*523fa7a6SAndroid Build Coastguard Worker                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
648*523fa7a6SAndroid Build Coastguard Worker                for v in outputs.values()
649*523fa7a6SAndroid Build Coastguard Worker            )
650*523fa7a6SAndroid Build Coastguard Worker
651*523fa7a6SAndroid Build Coastguard Worker    def build_graph(inputs, outputs):
652*523fa7a6SAndroid Build Coastguard Worker        # custom op declaration
653*523fa7a6SAndroid Build Coastguard Worker        inputs_str = "Tensor[] inputs"
654*523fa7a6SAndroid Build Coastguard Worker        func_proto = f"{op_name}({inputs_str}) -> Any"
655*523fa7a6SAndroid Build Coastguard Worker        custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
656*523fa7a6SAndroid Build Coastguard Worker        custom_op.define(func_proto)
657*523fa7a6SAndroid Build Coastguard Worker        # custom op implementation
658*523fa7a6SAndroid Build Coastguard Worker        implement_op(custom_op, op_name, outputs)
659*523fa7a6SAndroid Build Coastguard Worker
660*523fa7a6SAndroid Build Coastguard Worker        # model architecture mimicking context binary
661*523fa7a6SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
662*523fa7a6SAndroid Build Coastguard Worker            def forward(self, *inputs):
663*523fa7a6SAndroid Build Coastguard Worker                return getattr(
664*523fa7a6SAndroid Build Coastguard Worker                    getattr(torch.ops, OpContextLoader.namespace), op_name
665*523fa7a6SAndroid Build Coastguard Worker                ).default(inputs)
666*523fa7a6SAndroid Build Coastguard Worker
667*523fa7a6SAndroid Build Coastguard Worker        model = Model()
668*523fa7a6SAndroid Build Coastguard Worker        prog = torch.export.export(model, tuple(inputs.values()))
669*523fa7a6SAndroid Build Coastguard Worker        # bookkeeping for variables' life cycle
670*523fa7a6SAndroid Build Coastguard Worker        return {
671*523fa7a6SAndroid Build Coastguard Worker            "custom_op": custom_op,
672*523fa7a6SAndroid Build Coastguard Worker            "custom_module": model,
673*523fa7a6SAndroid Build Coastguard Worker            "exported_program": prog,
674*523fa7a6SAndroid Build Coastguard Worker        }
675*523fa7a6SAndroid Build Coastguard Worker
676*523fa7a6SAndroid Build Coastguard Worker    def build_tensor(tensors, dtype_map):
677*523fa7a6SAndroid Build Coastguard Worker        ret = OrderedDict()
678*523fa7a6SAndroid Build Coastguard Worker        for t in tensors:
679*523fa7a6SAndroid Build Coastguard Worker            dtype = t.GetDataType()
680*523fa7a6SAndroid Build Coastguard Worker            dtype_torch = dtype_map.get(dtype, None)
681*523fa7a6SAndroid Build Coastguard Worker            assert dtype_torch is not None, f"unknown qnn data type {dtype}"
682*523fa7a6SAndroid Build Coastguard Worker            ret[t.GetName()] = torch.zeros(tuple(t.GetDims()), dtype=dtype_torch)
683*523fa7a6SAndroid Build Coastguard Worker
684*523fa7a6SAndroid Build Coastguard Worker        return ret
685*523fa7a6SAndroid Build Coastguard Worker
686*523fa7a6SAndroid Build Coastguard Worker    def preprocess_binary(ctx_bin, compiler_specs):
687*523fa7a6SAndroid Build Coastguard Worker        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
688*523fa7a6SAndroid Build Coastguard Worker            generate_qnn_executorch_option(compiler_specs),
689*523fa7a6SAndroid Build Coastguard Worker        )
690*523fa7a6SAndroid Build Coastguard Worker        return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin))
691*523fa7a6SAndroid Build Coastguard Worker
692*523fa7a6SAndroid Build Coastguard Worker    # dummy compiler spec would be fine, since we're not compiling
693*523fa7a6SAndroid Build Coastguard Worker    backend_options = generate_htp_compiler_spec(use_fp16=False)
694*523fa7a6SAndroid Build Coastguard Worker    compiler_specs = generate_qnn_executorch_compiler_spec(
695*523fa7a6SAndroid Build Coastguard Worker        soc_model=soc_model,
696*523fa7a6SAndroid Build Coastguard Worker        backend_options=backend_options,
697*523fa7a6SAndroid Build Coastguard Worker        is_from_context_binary=True,
698*523fa7a6SAndroid Build Coastguard Worker    )
699*523fa7a6SAndroid Build Coastguard Worker
700*523fa7a6SAndroid Build Coastguard Worker    ctx_bin = (
701*523fa7a6SAndroid Build Coastguard Worker        ctx_path
702*523fa7a6SAndroid Build Coastguard Worker        if not isinstance(ctx_path, str)
703*523fa7a6SAndroid Build Coastguard Worker        else preprocess_binary(Path(f"{ctx_path}").read_bytes(), compiler_specs)
704*523fa7a6SAndroid Build Coastguard Worker    )
705*523fa7a6SAndroid Build Coastguard Worker
706*523fa7a6SAndroid Build Coastguard Worker    dtype_map = {}
707*523fa7a6SAndroid Build Coastguard Worker    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
708*523fa7a6SAndroid Build Coastguard Worker        for k, v in type_map.items():
709*523fa7a6SAndroid Build Coastguard Worker            dtype_map.setdefault(v, k)
710*523fa7a6SAndroid Build Coastguard Worker
711*523fa7a6SAndroid Build Coastguard Worker    if custom_info is not None:
712*523fa7a6SAndroid Build Coastguard Worker        # some context binaries might fail to open on the host
713*523fa7a6SAndroid Build Coastguard Worker        # if they were compiled with special flags
714*523fa7a6SAndroid Build Coastguard Worker        # (e.g. weight sharing); in that case, rely on
715*523fa7a6SAndroid Build Coastguard Worker        # the custom information provided by the caller instead
716*523fa7a6SAndroid Build Coastguard Worker        inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
717*523fa7a6SAndroid Build Coastguard Worker        outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
718*523fa7a6SAndroid Build Coastguard Worker        graph_name = custom_info["graph_name"]
719*523fa7a6SAndroid Build Coastguard Worker    else:
720*523fa7a6SAndroid Build Coastguard Worker        # get context-binary io tensor info through qnn manager
721*523fa7a6SAndroid Build Coastguard Worker        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
722*523fa7a6SAndroid Build Coastguard Worker            generate_qnn_executorch_option(compiler_specs),
723*523fa7a6SAndroid Build Coastguard Worker            ctx_bin,
724*523fa7a6SAndroid Build Coastguard Worker        )
725*523fa7a6SAndroid Build Coastguard Worker        assert qnn_mgr.Init().value == 0, "failed to load context binary"
726*523fa7a6SAndroid Build Coastguard Worker        # assume there is only one graph in the current context
727*523fa7a6SAndroid Build Coastguard Worker        graph_name = qnn_mgr.GetGraphNames()[0]
728*523fa7a6SAndroid Build Coastguard Worker        qnn_mgr.AllocateTensor(graph_name)
729*523fa7a6SAndroid Build Coastguard Worker        inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map)
730*523fa7a6SAndroid Build Coastguard Worker        outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map)
731*523fa7a6SAndroid Build Coastguard Worker        qnn_mgr.Destroy()
732*523fa7a6SAndroid Build Coastguard Worker
733*523fa7a6SAndroid Build Coastguard Worker    # generate a graph specifically for loading the context
734*523fa7a6SAndroid Build Coastguard Worker    bundle_prog = build_graph(inputs, outputs)
735*523fa7a6SAndroid Build Coastguard Worker    bundle_prog.update({"inputs": inputs, "outputs": outputs})
736*523fa7a6SAndroid Build Coastguard Worker    edge_prog_mgr = to_edge(
737*523fa7a6SAndroid Build Coastguard Worker        programs={graph_name: bundle_prog["exported_program"]},
738*523fa7a6SAndroid Build Coastguard Worker        # do not alter name for custom op
739*523fa7a6SAndroid Build Coastguard Worker        compile_config=EdgeCompileConfig(_use_edge_ops=False),
740*523fa7a6SAndroid Build Coastguard Worker    )
741*523fa7a6SAndroid Build Coastguard Worker    # update meta with context binary
742*523fa7a6SAndroid Build Coastguard Worker    for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
743*523fa7a6SAndroid Build Coastguard Worker        if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
744*523fa7a6SAndroid Build Coastguard Worker            n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
745*523fa7a6SAndroid Build Coastguard Worker            break
746*523fa7a6SAndroid Build Coastguard Worker
747*523fa7a6SAndroid Build Coastguard Worker    bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
748*523fa7a6SAndroid Build Coastguard Worker        QnnPartitioner(compiler_specs)
749*523fa7a6SAndroid Build Coastguard Worker    )
750*523fa7a6SAndroid Build Coastguard Worker    return bundle_prog
751*523fa7a6SAndroid Build Coastguard Worker
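# A minimal usage sketch (illustrative addition, not part of the original
# module): it shows how a pre-built QNN context binary could be turned into
# a .pte file. The file names and the SM8650 chipset below are assumptions
# made purely for demonstration.
def _example_from_context_binary():
    bundle = from_context_binary(
        ctx_path="model_ctx.bin",  # hypothetical pre-built context binary
        op_name="ctx_loader",
        soc_model=QcomChipset.SM8650,
    )
    # the returned dict carries the partitioned EdgeProgramManager
    exec_prog = bundle["edge_program_manager"].to_executorch()
    with open("model_ctx.pte", "wb") as f:
        f.write(exec_prog.buffer)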
752*523fa7a6SAndroid Build Coastguard Worker
753*523fa7a6SAndroid Build Coastguard Workerdef draw_graph(title, path, graph_module: torch.fx.GraphModule):
754*523fa7a6SAndroid Build Coastguard Worker    graph = passes.graph_drawer.FxGraphDrawer(graph_module, title)
755*523fa7a6SAndroid Build Coastguard Worker    with open(f"{path}/{title}.svg", "wb") as f:
756*523fa7a6SAndroid Build Coastguard Worker        f.write(graph.get_dot_graph().create_svg())
757*523fa7a6SAndroid Build Coastguard Worker
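# Hedged usage example (illustrative): dump an SVG rendering of any
# torch.fx.GraphModule, e.g.
#
#     draw_graph("my_graph", "/tmp", edge_prog.exported_program().graph_module)
#
# where `edge_prog` is a hypothetical EdgeProgramManager.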
758*523fa7a6SAndroid Build Coastguard Worker
759*523fa7a6SAndroid Build Coastguard Workerdef generate_multi_graph_program(
760*523fa7a6SAndroid Build Coastguard Worker    compiler_specs: List[CompileSpec],
761*523fa7a6SAndroid Build Coastguard Worker    processed_bytes: List[bytes],
762*523fa7a6SAndroid Build Coastguard Worker    backend_config: ExecutorchBackendConfig = None,
763*523fa7a6SAndroid Build Coastguard Worker) -> ExecutorchProgramManager:
764*523fa7a6SAndroid Build Coastguard Worker    # compile multiple graphs in qcir into a single context binary
765*523fa7a6SAndroid Build Coastguard Worker    graph_inputs, graph_outputs = {}, {}
766*523fa7a6SAndroid Build Coastguard Worker    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
767*523fa7a6SAndroid Build Coastguard Worker        generate_qnn_executorch_option(compiler_specs), processed_bytes
768*523fa7a6SAndroid Build Coastguard Worker    )
769*523fa7a6SAndroid Build Coastguard Worker    assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
770*523fa7a6SAndroid Build Coastguard Worker    binary_info = bytes(qnn_mgr.Compile())
771*523fa7a6SAndroid Build Coastguard Worker    assert len(binary_info) != 0, "failed to generate QNN context binary"
772*523fa7a6SAndroid Build Coastguard Worker    graph_names = qnn_mgr.GetGraphNames()
773*523fa7a6SAndroid Build Coastguard Worker    for graph_name in graph_names:
774*523fa7a6SAndroid Build Coastguard Worker        graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name)
775*523fa7a6SAndroid Build Coastguard Worker        graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
776*523fa7a6SAndroid Build Coastguard Worker    qnn_mgr.Destroy()
777*523fa7a6SAndroid Build Coastguard Worker
778*523fa7a6SAndroid Build Coastguard Worker    # build custom ops with different graph signatures
779*523fa7a6SAndroid Build Coastguard Worker    compiler_options = flatbuffer_to_option(compiler_specs[0].value)
780*523fa7a6SAndroid Build Coastguard Worker    bundle_progs = [
781*523fa7a6SAndroid Build Coastguard Worker        from_context_binary(
782*523fa7a6SAndroid Build Coastguard Worker            ctx_path=binary_info,
783*523fa7a6SAndroid Build Coastguard Worker            op_name=f"loader_{graph_name}",
784*523fa7a6SAndroid Build Coastguard Worker            soc_model=compiler_options.soc_info.soc_model,
785*523fa7a6SAndroid Build Coastguard Worker            custom_info={
786*523fa7a6SAndroid Build Coastguard Worker                "graph_inputs": graph_inputs[graph_name],
787*523fa7a6SAndroid Build Coastguard Worker                "graph_outputs": graph_outputs[graph_name],
788*523fa7a6SAndroid Build Coastguard Worker                "graph_name": graph_name,
789*523fa7a6SAndroid Build Coastguard Worker            },
790*523fa7a6SAndroid Build Coastguard Worker        )
791*523fa7a6SAndroid Build Coastguard Worker        for graph_name in graph_names
792*523fa7a6SAndroid Build Coastguard Worker    ]
793*523fa7a6SAndroid Build Coastguard Worker    # leverage ExecutorchProgramManager for generating a pte with multiple methods
794*523fa7a6SAndroid Build Coastguard Worker    edge_prog_mgr = to_edge(
795*523fa7a6SAndroid Build Coastguard Worker        programs={
796*523fa7a6SAndroid Build Coastguard Worker            graph_name: bundle_prog["exported_program"]
797*523fa7a6SAndroid Build Coastguard Worker            for graph_name, bundle_prog in zip(graph_names, bundle_progs)
798*523fa7a6SAndroid Build Coastguard Worker        },
799*523fa7a6SAndroid Build Coastguard Worker        # do not alter name for custom op
800*523fa7a6SAndroid Build Coastguard Worker        compile_config=EdgeCompileConfig(_use_edge_ops=False),
801*523fa7a6SAndroid Build Coastguard Worker    )
802*523fa7a6SAndroid Build Coastguard Worker    # restore meta lost in generating EdgeProgramManager
803*523fa7a6SAndroid Build Coastguard Worker    for graph_name in graph_names:
804*523fa7a6SAndroid Build Coastguard Worker        for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
805*523fa7a6SAndroid Build Coastguard Worker            if graph_name in n.name:
806*523fa7a6SAndroid Build Coastguard Worker                n.meta[OpContextLoader.meta_ctx_bin] = binary_info
807*523fa7a6SAndroid Build Coastguard Worker                break
808*523fa7a6SAndroid Build Coastguard Worker
809*523fa7a6SAndroid Build Coastguard Worker    return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch(
810*523fa7a6SAndroid Build Coastguard Worker        config=backend_config or ExecutorchBackendConfig()
811*523fa7a6SAndroid Build Coastguard Worker    )
812*523fa7a6SAndroid Build Coastguard Worker
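# A hedged sketch (illustrative) of driving the helper above:
# `processed_bytes_list` stands for per-graph qcir payloads collected from
# earlier lowering stages, and `compiler_specs` for the matching compile
# specs; both names are assumptions.
#
#     prog_mgr = generate_multi_graph_program(
#         compiler_specs=compiler_specs,
#         processed_bytes=processed_bytes_list,
#     )
#     with open("multi_method.pte", "wb") as f:
#         f.write(prog_mgr.buffer)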
813*523fa7a6SAndroid Build Coastguard Worker
814*523fa7a6SAndroid Build Coastguard Workerdef generate_htp_compiler_spec(
815*523fa7a6SAndroid Build Coastguard Worker    use_fp16: bool,
816*523fa7a6SAndroid Build Coastguard Worker    use_dlbc: bool = False,
817*523fa7a6SAndroid Build Coastguard Worker    use_multi_contexts: bool = False,
818*523fa7a6SAndroid Build Coastguard Worker) -> QnnExecuTorchBackendOptions:
819*523fa7a6SAndroid Build Coastguard Worker    """
820*523fa7a6SAndroid Build Coastguard Worker    Helper function generating backend options for QNN HTP
821*523fa7a6SAndroid Build Coastguard Worker
822*523fa7a6SAndroid Build Coastguard Worker    Args:
823*523fa7a6SAndroid Build Coastguard Worker        use_fp16: If true, the model is compiled for the QNN HTP fp16 runtime.
824*523fa7a6SAndroid Build Coastguard Worker            Note that not all SoCs support QNN HTP fp16; only premium-tier SoCs
825*523fa7a6SAndroid Build Coastguard Worker            such as Snapdragon 8 Gen 1 or newer support it.
826*523fa7a6SAndroid Build Coastguard Worker        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
827*523fa7a6SAndroid Build Coastguard Worker            compressed, so that the required processing bandwidth can be lowered.
828*523fa7a6SAndroid Build Coastguard Worker        use_multi_contexts: When multiple contexts are generated inside the same
829*523fa7a6SAndroid Build Coastguard Worker            pte, it is possible to reserve a single spill-fill allocation that
830*523fa7a6SAndroid Build Coastguard Worker            could be re-used across all the splits.
831*523fa7a6SAndroid Build Coastguard Worker
832*523fa7a6SAndroid Build Coastguard Worker    Returns:
833*523fa7a6SAndroid Build Coastguard Worker        QnnExecuTorchBackendOptions: backend options for QNN HTP.
834*523fa7a6SAndroid Build Coastguard Worker    """
835*523fa7a6SAndroid Build Coastguard Worker    htp_options = QnnExecuTorchHtpBackendOptions()
836*523fa7a6SAndroid Build Coastguard Worker    htp_options.precision = (
837*523fa7a6SAndroid Build Coastguard Worker        QnnExecuTorchHtpPrecision.kHtpFp16
838*523fa7a6SAndroid Build Coastguard Worker        if use_fp16
839*523fa7a6SAndroid Build Coastguard Worker        else QnnExecuTorchHtpPrecision.kHtpQuantized
840*523fa7a6SAndroid Build Coastguard Worker    )
841*523fa7a6SAndroid Build Coastguard Worker    # This is not actually an option that affects the compiled blob,
842*523fa7a6SAndroid Build Coastguard Worker    # but we have no other place to pass it to the execution stage.
843*523fa7a6SAndroid Build Coastguard Worker    # TODO: enable a voting mechanism in the runtime and expose this as an option
844*523fa7a6SAndroid Build Coastguard Worker    htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
845*523fa7a6SAndroid Build Coastguard Worker    htp_options.use_multi_contexts = use_multi_contexts
846*523fa7a6SAndroid Build Coastguard Worker    htp_options.use_dlbc = use_dlbc
847*523fa7a6SAndroid Build Coastguard Worker    return QnnExecuTorchBackendOptions(
848*523fa7a6SAndroid Build Coastguard Worker        backend_type=QnnExecuTorchBackendType.kHtpBackend,
849*523fa7a6SAndroid Build Coastguard Worker        htp_options=htp_options,
850*523fa7a6SAndroid Build Coastguard Worker    )
851*523fa7a6SAndroid Build Coastguard Worker
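# Hedged example (illustrative): a quantized-precision HTP option set, as it
# would typically be fed into generate_qnn_executorch_compiler_spec() below.
#
#     backend_options = generate_htp_compiler_spec(use_fp16=False)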
852*523fa7a6SAndroid Build Coastguard Worker
853*523fa7a6SAndroid Build Coastguard Workerdef generate_qnn_executorch_compiler_spec(
854*523fa7a6SAndroid Build Coastguard Worker    soc_model: QcomChipset,
855*523fa7a6SAndroid Build Coastguard Worker    backend_options: QnnExecuTorchBackendOptions,
856*523fa7a6SAndroid Build Coastguard Worker    debug: bool = False,
857*523fa7a6SAndroid Build Coastguard Worker    saver: bool = False,
858*523fa7a6SAndroid Build Coastguard Worker    online_prepare: bool = False,
859*523fa7a6SAndroid Build Coastguard Worker    dump_intermediate_outputs: bool = False,
860*523fa7a6SAndroid Build Coastguard Worker    profile: bool = False,
861*523fa7a6SAndroid Build Coastguard Worker    optrace: bool = False,
862*523fa7a6SAndroid Build Coastguard Worker    shared_buffer: bool = False,
863*523fa7a6SAndroid Build Coastguard Worker    is_from_context_binary: bool = False,
864*523fa7a6SAndroid Build Coastguard Worker    multiple_graphs: bool = False,
865*523fa7a6SAndroid Build Coastguard Worker    graph_name: str = "forward",
866*523fa7a6SAndroid Build Coastguard Worker) -> List[CompileSpec]:
867*523fa7a6SAndroid Build Coastguard Worker    """
868*523fa7a6SAndroid Build Coastguard Worker    Helper function generating compiler specs for Qualcomm AI Engine Direct
869*523fa7a6SAndroid Build Coastguard Worker
870*523fa7a6SAndroid Build Coastguard Worker    Args:
871*523fa7a6SAndroid Build Coastguard Worker        soc_model: The SoC on which you plan to run the compiled model. Please
872*523fa7a6SAndroid Build Coastguard Worker            check QcomChipset for the supported SoCs, e.g.:
873*523fa7a6SAndroid Build Coastguard Worker            SM8450 (Snapdragon 8 Gen 1)
874*523fa7a6SAndroid Build Coastguard Worker            SM8475 (Snapdragon 8 Gen 1+)
875*523fa7a6SAndroid Build Coastguard Worker            SM8550 (Snapdragon 8 Gen 2)
876*523fa7a6SAndroid Build Coastguard Worker            SM8650 (Snapdragon 8 Gen 3)
877*523fa7a6SAndroid Build Coastguard Worker        backend_options: Options required by different backends.
878*523fa7a6SAndroid Build Coastguard Worker        debug: Enable verbose logging. Disclaimer: this option is expected to
879*523fa7a6SAndroid Build Coastguard Worker            change in the near future.
880*523fa7a6SAndroid Build Coastguard Worker        online_prepare: Compose the QNN graph on device if set to True.
881*523fa7a6SAndroid Build Coastguard Worker        saver: Instead of compiling the model, run QNN Saver. Please check
882*523fa7a6SAndroid Build Coastguard Worker            the documents of the Qualcomm AI Engine Direct SDK. This feature
883*523fa7a6SAndroid Build Coastguard Worker            is usually used for debugging purposes.
884*523fa7a6SAndroid Build Coastguard Worker        dump_intermediate_outputs: If tensor dump is enabled, all intermediate tensor
885*523fa7a6SAndroid Build Coastguard Worker            outputs will be dumped. This option exists for debugging accuracy issues.
886*523fa7a6SAndroid Build Coastguard Worker        profile: Enable per-operator performance profiling. Note that for now only
887*523fa7a6SAndroid Build Coastguard Worker            kProfileDetailed is supported, reporting each operator's cost in cycles.
888*523fa7a6SAndroid Build Coastguard Worker        optrace: Collect optrace-level profiling data; takes precedence over 'profile'.
889*523fa7a6SAndroid Build Coastguard Worker        shared_buffer: Enables usage of shared buffer between application
890*523fa7a6SAndroid Build Coastguard Worker            and backend for graph I/O.
891*523fa7a6SAndroid Build Coastguard Worker        is_from_context_binary: True if the current graph comes from a pre-built context binary.
892*523fa7a6SAndroid Build Coastguard Worker        multiple_graphs: True if multiple methods are expected to live in a single .pte file.
893*523fa7a6SAndroid Build Coastguard Worker            Please see test cases for a post-processing example.
894*523fa7a6SAndroid Build Coastguard Worker        graph_name: Assign a unique graph name if 'multiple_graphs' is used.
895*523fa7a6SAndroid Build Coastguard Worker
896*523fa7a6SAndroid Build Coastguard Worker    Returns:
897*523fa7a6SAndroid Build Coastguard Worker        List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct.
898*523fa7a6SAndroid Build Coastguard Worker
899*523fa7a6SAndroid Build Coastguard Worker    Raises:
900*523fa7a6SAndroid Build Coastguard Worker        ValueError: The given QcomChipset value is currently not supported.
901*523fa7a6SAndroid Build Coastguard Worker        ValueError: Conflicting compiler specs were given.
902*523fa7a6SAndroid Build Coastguard Worker    """
903*523fa7a6SAndroid Build Coastguard Worker    _supported_soc_models = {soc_model.value for soc_model in QcomChipset}
904*523fa7a6SAndroid Build Coastguard Worker    if soc_model not in _supported_soc_models:
905*523fa7a6SAndroid Build Coastguard Worker        raise ValueError(f"unknown SoC model for QNN: {soc_model}")
906*523fa7a6SAndroid Build Coastguard Worker
907*523fa7a6SAndroid Build Coastguard Worker    if profile and dump_intermediate_outputs:
908*523fa7a6SAndroid Build Coastguard Worker        warnings.warn(
909*523fa7a6SAndroid Build Coastguard Worker            "It is not recommended to turn on both profiling and dump_intermediate_outputs at the same time"
910*523fa7a6SAndroid Build Coastguard Worker            ", because dump_intermediate_outputs will cause a performance drop.",
911*523fa7a6SAndroid Build Coastguard Worker            stacklevel=1,
912*523fa7a6SAndroid Build Coastguard Worker        )
913*523fa7a6SAndroid Build Coastguard Worker
914*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options = QnnExecuTorchOptions(
915*523fa7a6SAndroid Build Coastguard Worker        _soc_info_table[soc_model], backend_options
916*523fa7a6SAndroid Build Coastguard Worker    )
917*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.graph_name = graph_name
918*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.log_level = (
919*523fa7a6SAndroid Build Coastguard Worker        QnnExecuTorchLogLevel.kLogLevelDebug
920*523fa7a6SAndroid Build Coastguard Worker        if debug
921*523fa7a6SAndroid Build Coastguard Worker        else QnnExecuTorchLogLevel.kLogLevelWarn
922*523fa7a6SAndroid Build Coastguard Worker    )
923*523fa7a6SAndroid Build Coastguard Worker
924*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs
925*523fa7a6SAndroid Build Coastguard Worker
926*523fa7a6SAndroid Build Coastguard Worker    if saver:
927*523fa7a6SAndroid Build Coastguard Worker        qnn_executorch_options.library_path = "libQnnSaver.so"
928*523fa7a6SAndroid Build Coastguard Worker
929*523fa7a6SAndroid Build Coastguard Worker    if optrace:
930*523fa7a6SAndroid Build Coastguard Worker        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace
931*523fa7a6SAndroid Build Coastguard Worker    elif profile:
932*523fa7a6SAndroid Build Coastguard Worker        qnn_executorch_options.profile_level = (
933*523fa7a6SAndroid Build Coastguard Worker            QnnExecuTorchProfileLevel.kProfileDetailed
934*523fa7a6SAndroid Build Coastguard Worker        )
935*523fa7a6SAndroid Build Coastguard Worker    else:
936*523fa7a6SAndroid Build Coastguard Worker        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff
937*523fa7a6SAndroid Build Coastguard Worker
938*523fa7a6SAndroid Build Coastguard Worker    if (
939*523fa7a6SAndroid Build Coastguard Worker        online_prepare
940*523fa7a6SAndroid Build Coastguard Worker        and backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend
941*523fa7a6SAndroid Build Coastguard Worker        and backend_options.htp_options.use_multi_contexts
942*523fa7a6SAndroid Build Coastguard Worker    ):
943*523fa7a6SAndroid Build Coastguard Worker        raise ValueError(
944*523fa7a6SAndroid Build Coastguard Worker            "'use_multi_contexts' cannot function in online prepare mode, "
945*523fa7a6SAndroid Build Coastguard Worker            "please set 'online_prepare' to False"
946*523fa7a6SAndroid Build Coastguard Worker        )
947*523fa7a6SAndroid Build Coastguard Worker
948*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.shared_buffer = shared_buffer
949*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.online_prepare = online_prepare
950*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.is_from_context_binary = is_from_context_binary
951*523fa7a6SAndroid Build Coastguard Worker    qnn_executorch_options.multiple_graphs = multiple_graphs
952*523fa7a6SAndroid Build Coastguard Worker
953*523fa7a6SAndroid Build Coastguard Worker    if multiple_graphs:
954*523fa7a6SAndroid Build Coastguard Worker        # enable weight sharing mechanism if multiple graphs appear
955*523fa7a6SAndroid Build Coastguard Worker        if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend:
956*523fa7a6SAndroid Build Coastguard Worker            backend_options.htp_options.use_weight_sharing = True
957*523fa7a6SAndroid Build Coastguard Worker
958*523fa7a6SAndroid Build Coastguard Worker    return [
959*523fa7a6SAndroid Build Coastguard Worker        CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options))
960*523fa7a6SAndroid Build Coastguard Worker    ]
961*523fa7a6SAndroid Build Coastguard Worker
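# A hedged end-to-end sketch (illustrative, not the canonical flow): build the
# HTP backend options plus compiler specs, then lower a toy module through the
# QNN partitioner. The module, tensor shapes, and chipset are demonstration
# assumptions.
def _example_lower_with_compiler_specs():
    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    backend_options = generate_htp_compiler_spec(use_fp16=True)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=backend_options,
    )
    sample_inputs = (torch.randn(1, 8), torch.randn(1, 8))
    edge_prog_mgr = to_edge(torch.export.export(Add(), sample_inputs))
    # delegate partitioned subgraphs to QNN, then serialize to a .pte
    exec_prog = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch()
    with open("add_htp.pte", "wb") as f:
        f.write(exec_prog.buffer)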
962*523fa7a6SAndroid Build Coastguard Worker
963*523fa7a6SAndroid Build Coastguard Workerdef get_soc_to_arch_map():
964*523fa7a6SAndroid Build Coastguard Worker    return {
965*523fa7a6SAndroid Build Coastguard Worker        "SSG2115P": HtpArch.V73,
966*523fa7a6SAndroid Build Coastguard Worker        "SM8650": HtpArch.V75,
967*523fa7a6SAndroid Build Coastguard Worker        "SM8550": HtpArch.V73,
968*523fa7a6SAndroid Build Coastguard Worker        "SM8475": HtpArch.V69,
969*523fa7a6SAndroid Build Coastguard Worker        "SM8450": HtpArch.V69,
970*523fa7a6SAndroid Build Coastguard Worker        "SA8295": HtpArch.V68,
971*523fa7a6SAndroid Build Coastguard Worker    }
972*523fa7a6SAndroid Build Coastguard Worker
973*523fa7a6SAndroid Build Coastguard Worker
974*523fa7a6SAndroid Build Coastguard Workerdef get_soc_to_chipset_map():
975*523fa7a6SAndroid Build Coastguard Worker    return {
976*523fa7a6SAndroid Build Coastguard Worker        "SSG2115P": QcomChipset.SSG2115P,
977*523fa7a6SAndroid Build Coastguard Worker        "SM8650": QcomChipset.SM8650,
978*523fa7a6SAndroid Build Coastguard Worker        "SM8550": QcomChipset.SM8550,
979*523fa7a6SAndroid Build Coastguard Worker        "SM8475": QcomChipset.SM8475,
980*523fa7a6SAndroid Build Coastguard Worker        "SM8450": QcomChipset.SM8450,
981*523fa7a6SAndroid Build Coastguard Worker        "SA8295": QcomChipset.SA8295,
982*523fa7a6SAndroid Build Coastguard Worker    }
983*523fa7a6SAndroid Build Coastguard Worker
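# Hedged example (illustrative): look up the HTP architecture and chipset enum
# for a known board name such as "SM8650" (the name is an assumption):
#
#     arch = get_soc_to_arch_map()["SM8650"]
#     chipset = get_soc_to_chipset_map()["SM8650"]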
984*523fa7a6SAndroid Build Coastguard Worker
985*523fa7a6SAndroid Build Coastguard Workerdef tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
986*523fa7a6SAndroid Build Coastguard Worker    """
987*523fa7a6SAndroid Build Coastguard Worker    Tag IO nodes that consume or produce quantized tensors; no need to insert Q/DQ ops in qnn_preprocess.
988*523fa7a6SAndroid Build Coastguard Worker    """
989*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
990*523fa7a6SAndroid Build Coastguard Worker        if dtype := get_quant_io_dtype_fn(node):
991*523fa7a6SAndroid Build Coastguard Worker            node.meta[QCOM_QUANTIZED_IO] = dtype
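
# A minimal sketch (illustrative) of a dtype callback for tag_quant_io: it tags
# graph placeholders and output-feeding nodes as quantized uint8. The predicate
# below is an assumption about how a caller identifies IO nodes.
def _example_tag_quant_io(gm: torch.fx.GraphModule):
    def get_quant_io_dtype(node: torch.fx.Node):
        is_input = node.op == "placeholder"
        is_output = any(user.op == "output" for user in node.users)
        # returning None leaves the node untagged
        return torch.uint8 if (is_input or is_output) else None

    tag_quant_io(gm, get_quant_io_dtype)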