# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import warnings
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Tuple

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

import executorch.exir as exir

import torch
from executorch.backends.qualcomm._passes.annotate_and_quant_scalar import (
    AnnotateAndQuantScalar,
)
from executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed
from executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs
from executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import (
    ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm._passes.convert_bmm_to_matmul import (
    ConvertBmmToMatmul,
)
from executorch.backends.qualcomm._passes.convert_interpolate_with_upsample2d import (
    ConvertInterpolateWithUpsample2D,
)
from executorch.backends.qualcomm._passes.convert_prelu import ConvertPReLU
from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
    ExpandBroadcastTensorShape,
)
from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ
from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.recompose_rms_norm import RecomposeRmsNorm
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy
from executorch.backends.qualcomm._passes.replace_index_put_input import (
    ReplaceIndexPutInput,
)

from executorch.backends.qualcomm.builders.node_visitor import (
    QNN_QUANT_TYPE_MAP,
    QNN_TENSOR_TYPE_MAP,
)
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.qnn_partitioner import (
    generate_qnn_executorch_option,
    QnnPartitioner,
)
from executorch.backends.qualcomm.serialization.qc_schema import (
    _soc_info_table,
    HtpArch,
    QcomChipset,
    QnnExecuTorchBackendOptions,
    QnnExecuTorchBackendType,
    QnnExecuTorchHtpBackendOptions,
    QnnExecuTorchHtpPerformanceMode,
    QnnExecuTorchHtpPrecision,
    QnnExecuTorchLogLevel,
    QnnExecuTorchOptions,
    QnnExecuTorchProfileLevel,
)
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
    flatbuffer_to_option,
    option_to_flatbuffer,
)
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_EXPAND_BROADCAST_SHAPE,
    QCOM_PASS_SKIP_ADVANCED_REQUANT,
    QCOM_QNN_COMPILE_SPEC,
    QCOM_QUANTIZED_IO,
)

from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchProgramManager,
    ExirExportedProgram,
    to_edge,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.capture import ExecutorchBackendConfig
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.program._program import _get_updated_graph_signature
from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.library import Library


class _AnnotationSkipper(OperatorSupportBase):
    """
    Class used to partition out unwanted graph nodes,
    e.g. - nodes that are excluded from quantization annotation
         - nodes that have been grouped together as a submodule

    Attributes
    ----------
    fp_node_id_set : set
        set of node names to be kept in fp precision
    fp_node_op_set : set
        set of node targets (aten dialect) to be kept in fp precision
    skip_annotated_submodule : bool
        flag indicating whether annotated submodules should be skipped

    Methods
    -------
    should_delegate(n: torch.fx.Node)
        identify residual nodes that have not been lowered to fixed precision
    should_skip(n: torch.fx.Node)
        identify nodes that should be kept out of fixed precision
    is_node_supported(_, node: torch.fx.Node)
        overridden method for graph partitioning
    """

    def __init__(
        self,
        fp_node_id_set: set = None,
        fp_node_op_set: set = None,
        skip_annotated_submodule: bool = False,
    ):
        self.fp_node_id_set = fp_node_id_set
        self.fp_node_op_set = fp_node_op_set
        self.skip_annotated_submodule = skip_annotated_submodule

    def should_delegate(self, n: torch.fx.Node):
        return n.op == "call_function" and n.target != operator.getitem

    def should_skip(self, n: torch.fx.Node):
        return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        if self.skip_annotated_submodule:
            if node.op == "get_attr":
                return all(self.should_delegate(user) for user in node.users)
            return self.should_delegate(node)

        if any(
            [
                node.op in ("placeholder", "output"),
                self.should_skip(node),
                # check if parameters belong to a fallback operator
                (
                    node.op == "get_attr"
                    and all(self.should_skip(user) for user in node.users)
                ),
            ]
        ):
            print(f"[QNN Quantizer Annotation]: {node.name} | Skipped")
            return False

        return True


def qnn_capture_config():
    return exir.CaptureConfig(enable_aot=True)


def qnn_edge_config() -> exir.EdgeCompileConfig:
    return exir.EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )


def convert_linear_to_conv2d(module: torch.nn.Module):
    class Conv2D(torch.nn.Module):
        def __init__(self, weight, bias=None):
            super().__init__()
            use_bias = bias is not None
            # nn.Linear weight is laid out as (out_features, in_features)
            self.conv = torch.nn.Conv2d(
                in_channels=weight.shape[1],
                out_channels=weight.shape[0],
                kernel_size=1,
                padding=0,
                bias=use_bias,
            )
            self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1))
            if use_bias:
                self.conv.bias = torch.nn.Parameter(bias)

        def forward(self, x):
            rank = x.dim()
            x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
            x = torch.transpose(x, 1, 2)
            res = self.conv(x)
            res = torch.transpose(res, 1, 2)
            res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
            return res

    def replace_linear(module: torch.nn.Module):
        attr_strs = dir(module)
        if isinstance(module, torch.nn.ModuleList):
            attr_strs += [str(i) for i in range(len(module))]

        for attr_str in attr_strs:
            target_attr = getattr(module, attr_str)
            if isinstance(target_attr, torch.nn.Linear):
                setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias))

        for _, sub_module in module.named_children():
            sub_module = replace_linear(sub_module)
        return module

    return replace_linear(module)

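# Example (illustrative sketch, not executed here; `MyModel` is a hypothetical
# module name): swap every nn.Linear in a model for its 1x1-Conv2d equivalent
# before exporting.
#
#     model = convert_linear_to_conv2d(MyModel().eval())
#     out = model(torch.randn(1, 32, 128))  # numerically equivalent to the Linear version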

def update_spill_fill_size(
    exported_program: ExportedProgram | List[LoweredBackendModule],
):
    # check if the user specified use_multi_contexts
    # this is a generic approach in case multiple backends exist
    def get_program_info(program):
        def process_exported_program(prog):
            max_sf_buf_size, module_map = 0, {}
            for _, m in prog.graph_module._modules.items():
                # currently only 1 compile spec is expected in each partition
                options = flatbuffer_to_option(m.compile_specs[0].value)
                if (
                    options.backend_options.backend_type
                    == QnnExecuTorchBackendType.kHtpBackend
                    and options.backend_options.htp_options.use_multi_contexts
                ):
                    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                        m.compile_specs[0].value, m.processed_bytes
                    )
                    assert qnn_mgr.Init().value == 0, "failed to load context binary"
                    max_sf_buf_size = max(
                        max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize()
                    )
                    module_map[m] = options
                    qnn_mgr.Destroy()
            return max_sf_buf_size, module_map

        def process_lowered_module(module):
            qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                module.compile_specs[0].value, module.processed_bytes
            )
            assert qnn_mgr.Init().value == 0, "failed to load context binary"
            spill_fill_size = qnn_mgr.GetSpillFillBufferSize()
            qnn_mgr.Destroy()
            return spill_fill_size, {
                module: flatbuffer_to_option(module.compile_specs[0].value)
            }

        dispatch = {
            ExportedProgram: process_exported_program,
            LoweredBackendModule: process_lowered_module,
        }
        return dispatch[type(program)](program)

    def update_program(max_sf_buf_size, module_map):
        def set_spec(module, options):
            spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(options))
            if isinstance(module, ExportedProgram):
                module.compile_specs[0] = spec
            else:
                module._compile_specs[0] = spec

        for module, options in module_map.items():
            options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
            set_spec(module, options)

    if isinstance(exported_program, list):
        max_sf_size, modules_map = 0, {}
        for prog in exported_program:
            max_sf_buf_size, module_map = get_program_info(prog)
            max_sf_size = max(max_sf_size, max_sf_buf_size)
            modules_map.update(module_map)
        update_program(max_sf_size, modules_map)
    else:
        update_program(*get_program_info(exported_program))

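# Example (illustrative sketch, assuming `compiler_specs` were built with
# use_multi_contexts=True and `edge_prog` came from capture_program): refresh
# the spill-fill buffer size recorded in each partition after delegation.
#
#     from executorch.exir.backend.backend_api import to_backend
#
#     delegated = to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs))
#     update_spill_fill_size(delegated)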

def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
    source_decompositions = torch_core_aten_decompositions()
    # The ops below are supported natively by QNN, so their decompositions are removed
    remove_decompositions = [
        torch.ops.aten.pixel_shuffle.default,
        torch.ops.aten.pixel_unshuffle.default,
        torch.ops.aten.hardsigmoid.default,
        torch.ops.aten.hardswish.default,
        torch.ops.aten._safe_softmax.default,
    ]

    for key in remove_decompositions:
        source_decompositions.pop(key)

    return source_decompositions

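# Example (illustrative sketch): run torch.export decomposition with the trimmed
# table so ops QNN handles natively are kept whole; this mirrors what
# capture_program() does below.
#
#     ep = torch.export.export(model, sample_inputs)
#     ep = ep.run_decompositions(get_decomp_table())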

def _transform(
    edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
) -> ExportedProgram:
    # currently ExirExportedProgram.transform does not accept
    # changes to the number of inputs, which FoldQDQ introduces,
    # so apply passes one by one here to avoid IR capture failure
    graph_module = edge_program.graph_module
    RemoveRedundancy()(graph_module)
    RecomposePixelUnshuffle()(graph_module)
    RecomposeRmsNorm()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertPReLU(edge_program)(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    AnnotateQuantAttrs(
        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    )(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    # this pass is not necessary for networks without layout-sensitive ops;
    # enabling it by default would introduce overhead from extra view_copy nodes
    if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
        ExpandBroadcastTensorShape()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    ReplaceIndexPutInput(edge_program)(graph_module)

    # Since QDQ nodes are stripped, update the graph signature again to validate the program
    edge_program._graph_signature = _get_updated_graph_signature(
        edge_program.graph_signature,
        edge_program.graph_module,
    )
    edge_program._validate()
    return edge_program


def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
    ep = torch.export.export(module, inputs)
    decomposed_ep = ep.run_decompositions(get_decomp_table())
    # We select call_operator nodes by target in ConvertBinaryOpsWithScalar
    # because they share the same source_fn_stack in MultiheadAttention
    # TODO: Should modify the scalar op in the op builder instead of
    #       using transformation
    core_ep = ExirExportedProgram(decomposed_ep, False)
    core_ep.transform(ConvertBinaryOpsWithScalar())
    edge_ep = core_ep.to_edge(qnn_edge_config())
    _transform(edge_ep.exported_program, custom_pass_config)
    return edge_ep

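# Example (illustrative sketch; `model` and `sample_inputs` are placeholders):
# capture a module into the QNN-flavored edge dialect, then delegate it the same
# way _canonicalize_graph_with_lowered_module() does below.
#
#     from executorch.exir.backend.backend_api import to_backend
#
#     edge_prog = capture_program(model, sample_inputs)
#     lowered = to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs))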

def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn):
    from torch.fx.passes.utils.fuser_utils import (
        erase_nodes,
        fuse_as_graphmodule,
        insert_subgm,
        legalize_graph,
        topo_sort,
    )

    partitions = ptn.propose_partitions()
    # insert meta for each partition group
    for i, partition in enumerate(partitions):
        for node in partition.nodes:
            node.meta[subgm_tag] = i

    for i in range(len(partitions)):
        # find nodes with same group id in current graph
        node_list = [
            node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i
        ]
        # fuse group nodes into submodule
        sorted_nodes = topo_sort(node_list)
        submodule_name = f"{subgm_tag}_{i}"
        subgm, orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, sorted_nodes, submodule_name
        )
        # insert submodule & trim group nodes
        gm = insert_subgm(
            gm,
            subgm_cb(subgm, submodule_name),
            orig_inputs,
            orig_outputs,
        )
        erase_nodes(gm, sorted_nodes)
        legalize_graph(gm)

    gm.recompile()
    return gm


def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn):
    from executorch.exir.backend.backend_api import to_backend

    # return lowered programs for the user to debug
    exported_progs = []
    # partition each submodule which went through convert_pt2e
    for node in gm.graph.nodes:
        if node.op == "call_module" and subgm_tag in node.name:
            # obtain sample inputs through meta
            subgm_input = [
                torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype)
                for arg in node.args
            ]
            # program meets QNN backend requirement
            sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input))
            # start lowering with given partitioner
            exported_progs.append(to_backend(sub_prog.exported_program, ptn))
            # replace submodule with lowered module
            gm.set_submodule(
                node.name,
                exported_progs[-1].graph_module,
            )
            # if node has multiple outputs, getitems will be generated by default
            if all(n.target != operator.getitem for n in node.users):
                with gm.graph.inserting_after(node):
                    getitem_node = gm.graph.call_function(
                        operator.getitem,
                        (node, 0),
                    )
                    getitem_node.meta = node.meta
                    node.replace_all_uses_with(
                        replace_with=getitem_node,
                        delete_user_cb=lambda user: user.target != operator.getitem,
                    )

    gm.recompile()
    return gm, exported_progs


def skip_annotation(
    nn_module: torch.nn.Module,
    quantizer,
    partitioner,
    sample_input: Tuple[torch.Tensor, ...],
    calibration_cb: Callable[[torch.fx.GraphModule], None],
    fp_node_id_set: set = None,
    fp_node_op_set: set = None,
    fallback_to_cpu: bool = True,
):
    r"""
    Exclude specific operators from quantizer annotation.
    Skipped operators stay on CPU by default; set 'fallback_to_cpu'
    to False to try delegating them with FP16 precision.

    e.g.: consider the following graph:
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \_______     _______/
                         add_1
             (torch.ops.aten.add.default)
                           |
                         output

    If the user wants to skip a convolution op by name with
    'fp_node_id_set' = {"conv2d_1"}
    "bias_1 / weight_1 / input_1 / input_2 / conv2d_1"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   input_2
      | (placeholder) |          |
       \      |      /           |
        \     |     /            |
         \    |    /             |
           conv2d_1              |
              \                 /
               \               /
                \             /
               lowered_module_1
            (QNN fixed precision)
                      |
                    output

    If the user wants to skip convolution ops by target with
    'fp_node_op_set' = {torch.ops.aten.conv2d.default}
    "bias_1 / weight_1 / input_1 / conv2d_1,
     bias_2 / weight_2 / input_2 / conv2d_2"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \__               __/
                    lowered_module_1
                 (QNN fixed precision)
                           |
                         output

    If the user wants to delegate the skipped conv2d from the graph above
    with 'fallback_to_cpu' = False:

    [Generated graph]
       input_1         input_2
    (placeholder)   (placeholder)
          |               |
          \               /
          lowered_module_2
         (QNN fp16 precision)
                  |
                  |
          lowered_module_1
         (QNN fixed precision)
                  |
                output

    Args:
        nn_module (torch.nn.Module): The module to be lowered.
        quantizer (QnnQuantizer): Instance of QnnQuantizer.
        partitioner (QnnPartitioner): Instance of QnnPartitioner.
        sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting.
        calibration_cb (callable): Callback function for user-defined calibration.
        fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision.
        fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision.
        fallback_to_cpu (bool): Whether skipped nodes stay on CPU (True) or are lowered to fp16 (False).

    Returns:
        exported_programs: List of programs lowered to QnnBackend (quantized graphs only).
    """
    from executorch.backends.qualcomm.serialization.qc_schema import (
        QnnExecuTorchHtpPrecision,
    )
    from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
        flatbuffer_to_option,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

    def prepare_subgm(subgm, subgm_name):
        # prepare current submodule for quantization annotation
        subgm_prepared = prepare_pt2e(subgm, quantizer)
        # overwrite this attribute, otherwise the name would be set to "GraphModule"
        # and we could not identify each submodule
        subgm_prepared.__class__.__name__ = subgm_name
        return subgm_prepared

    fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
    fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
    graph_module = torch.export.export(nn_module, sample_input).module()
    # define node support type
    capability_partitioner = CapabilityBasedPartitioner(
        graph_module,
        _AnnotationSkipper(fp_node_id_set, fp_node_op_set),
        allows_single_node_partition=True,
    )
    subgm_tag = "annotated_group"
    graph_module = _partition_graph_into_submodules(
        gm=graph_module,
        subgm_tag=subgm_tag,
        subgm_cb=prepare_subgm,
        ptn=capability_partitioner,
    )
    # perform calibration
    calibration_cb(graph_module)
    # convert submodules which went through prepare_pt2e
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            graph_module.set_submodule(
                node.name, convert_pt2e(graph_module.get_submodule(node.name))
            )
    # canonicalize graph for lowering again
    graph_module, exported_progs = _canonicalize_graph_with_lowered_module(
        gm=graph_module,
        subgm_tag=subgm_tag,
        ptn=partitioner,
    )

    if not fallback_to_cpu:
        try:
            from executorch.exir.backend.partitioner import DelegationSpec

            # change HTP compiler spec for hardware to enable fp16
            qnn_option = generate_qnn_executorch_option(
                partitioner.compiler_specs_snapshot
            )
            compile_option = flatbuffer_to_option(qnn_option)
            htp_options = compile_option.backend_options.htp_options
            htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16
            partitioner.delegation_spec = DelegationSpec(
                "QnnBackend",
                [
                    CompileSpec(
                        QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option)
                    )
                ],
            )
        except Exception:
            print(
                "Failed to change HTP compiler spec with 'use_fp16' set to True;"
                " skipped operators will fall back to CPU."
            )
            return graph_module, exported_progs

        # try lowering skipped operators into fp16
        capability_partitioner = CapabilityBasedPartitioner(
            graph_module,
            _AnnotationSkipper(skip_annotated_submodule=True),
            allows_single_node_partition=True,
        )
        subgm_tag = "skipped_group"
        graph_module = _partition_graph_into_submodules(
            gm=graph_module,
            subgm_tag=subgm_tag,
            subgm_cb=lambda subgm, _: subgm,
            ptn=capability_partitioner,
        )
        graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module(
            gm=graph_module,
            subgm_tag=subgm_tag,
            ptn=partitioner,
        )
        exported_progs.extend(exported_progs_fp)

    return graph_module, exported_progs

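# Example (illustrative sketch, assuming QnnQuantizer is importable from
# executorch.backends.qualcomm.quantizer.quantizer and `compiler_specs` came
# from generate_qnn_executorch_compiler_spec): quantize and lower everything
# except conv2d ops, which are partitioned out by target.
#
#     quantizer = QnnQuantizer()
#     partitioner = QnnPartitioner(compiler_specs)
#
#     def calibrate(gm):
#         gm(*sample_inputs)
#
#     graph_module, lowered_progs = skip_annotation(
#         nn_module=model,
#         quantizer=quantizer,
#         partitioner=partitioner,
#         sample_input=sample_inputs,
#         calibration_cb=calibrate,
#         fp_node_op_set={torch.ops.aten.conv2d.default},
#     )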

def from_context_binary(  # noqa: C901
    ctx_path: str | bytes,
    op_name: str,
    soc_model: QcomChipset = QcomChipset.SM8650,
    custom_info: Dict = None,
):
    from pathlib import Path

    def implement_op(custom_op, op_name, outputs):
        @torch.library.impl(
            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
        )
        def op_impl(inputs: List[torch.Tensor]):
            return tuple(
                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
                for v in outputs.values()
            )

    def build_graph(inputs, outputs):
        # custom op declaration
        inputs_str = "Tensor[] inputs"
        func_proto = f"{op_name}({inputs_str}) -> Any"
        custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
        custom_op.define(func_proto)
        # custom op implementation
        implement_op(custom_op, op_name, outputs)

        # model architecture mimicking context binary
        class Model(torch.nn.Module):
            def forward(self, *inputs):
                return getattr(
                    getattr(torch.ops, OpContextLoader.namespace), op_name
                ).default(inputs)

        model = Model()
        prog = torch.export.export(model, tuple(inputs.values()))
        # bookkeeping for variables' life cycle
        return {
            "custom_op": custom_op,
            "custom_module": model,
            "exported_program": prog,
        }

    def build_tensor(tensors, dtype_map):
        ret = OrderedDict()
        for t in tensors:
            dtype = t.GetDataType()
            dtype_torch = dtype_map.get(dtype, None)
            assert dtype_torch is not None, f"unknown qnn data type {dtype}"
            ret[t.GetName()] = torch.zeros(tuple(t.GetDims()), dtype=dtype_torch)

        return ret

    def preprocess_binary(ctx_bin, compiler_specs):
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
        )
        return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin))

    # a dummy compiler spec is fine here, since we're not compiling
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
        is_from_context_binary=True,
    )

    ctx_bin = (
        ctx_path
        if not isinstance(ctx_path, str)
        else preprocess_binary(Path(f"{ctx_path}").read_bytes(), compiler_specs)
    )

    dtype_map = {}
    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
        for k, v in type_map.items():
            dtype_map.setdefault(v, k)

    if custom_info is not None:
        # since some context binaries might fail to open on the host
        # if they were compiled with special flags (e.g. weight sharing),
        # use the custom information here instead
        inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
        outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
        graph_name = custom_info["graph_name"]
    else:
        # get context-binary io tensor info through qnn manager
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
            ctx_bin,
        )
        assert qnn_mgr.Init().value == 0, "failed to load context binary"
        # assume there is only one graph in the current context
        graph_name = qnn_mgr.GetGraphNames()[0]
        qnn_mgr.AllocateTensor(graph_name)
        inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map)
        outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map)
        qnn_mgr.Destroy()

    # generate a graph specifically for loading the context
    bundle_prog = build_graph(inputs, outputs)
    bundle_prog.update({"inputs": inputs, "outputs": outputs})
    edge_prog_mgr = to_edge(
        programs={graph_name: bundle_prog["exported_program"]},
        # do not alter the name of the custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # update meta with context binary
    for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
        if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
            n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
            break

    bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
        QnnPartitioner(compiler_specs)
    )
    return bundle_prog

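# Example (illustrative sketch; the context-binary path and op name are
# placeholders): wrap a pre-built QNN context binary as a loadable custom op
# and serialize the lowered program.
#
#     bundle = from_context_binary("model_ctx.bin", "ctx_loader", QcomChipset.SM8650)
#     exec_prog = bundle["edge_program_manager"].to_executorch()
#     with open("ctx_loader.pte", "wb") as f:
#         f.write(exec_prog.buffer)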

def draw_graph(title, path, graph_module: torch.fx.GraphModule):
    graph = passes.graph_drawer.FxGraphDrawer(graph_module, title)
    with open(f"{path}/{title}.svg", "wb") as f:
        f.write(graph.get_dot_graph().create_svg())


def generate_multi_graph_program(
    compiler_specs: List[CompileSpec],
    processed_bytes: List[bytes],
    backend_config: ExecutorchBackendConfig = None,
) -> ExecutorchProgramManager:
    # compile multiple graphs in qcir into a single context binary
    graph_inputs, graph_outputs = {}, {}
    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
        generate_qnn_executorch_option(compiler_specs), processed_bytes
    )
    assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
    binary_info = bytes(qnn_mgr.Compile())
    assert len(binary_info) != 0, "failed to generate QNN context binary"
    graph_names = qnn_mgr.GetGraphNames()
    for graph_name in graph_names:
        graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name)
        graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
    qnn_mgr.Destroy()

    # build custom ops with different graph signatures
    compiler_options = flatbuffer_to_option(compiler_specs[0].value)
    bundle_progs = [
        from_context_binary(
            ctx_path=binary_info,
            op_name=f"loader_{graph_name}",
            soc_model=compiler_options.soc_info.soc_model,
            custom_info={
                "graph_inputs": graph_inputs[graph_name],
                "graph_outputs": graph_outputs[graph_name],
                "graph_name": graph_name,
            },
        )
        for graph_name in graph_names
    ]
    # leverage ExecutorchProgramManager to generate a pte with multiple methods
    edge_prog_mgr = to_edge(
        programs={
            graph_name: bundle_prog["exported_program"]
            for graph_name, bundle_prog in zip(graph_names, bundle_progs)
        },
        # do not alter the name of the custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # restore meta lost while generating the EdgeProgramManager
    for graph_name in graph_names:
        for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
            if graph_name in n.name:
                n.meta[OpContextLoader.meta_ctx_bin] = binary_info
                break

    return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch(
        config=backend_config or ExecutorchBackendConfig()
    )

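# Example (illustrative sketch, assuming `lowered_modules` is a list of
# LoweredBackendModule objects that share compatible compiler specs): merge
# their processed bytes into one context binary and emit a multi-method pte.
#
#     exec_prog_mgr = generate_multi_graph_program(
#         compiler_specs=compiler_specs,
#         processed_bytes=[m.processed_bytes for m in lowered_modules],
#     )
#     with open("multi_graph.pte", "wb") as f:
#         f.write(exec_prog_mgr.buffer)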

def generate_htp_compiler_spec(
    use_fp16: bool,
    use_dlbc: bool = False,
    use_multi_contexts: bool = False,
) -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN HTP.

    Args:
        use_fp16: If true, the model is compiled for the QNN HTP fp16 runtime.
            Note that not all SoCs support QNN HTP fp16; only premium-tier SoCs
            such as Snapdragon 8 Gen 1 or newer support it.
        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
            compressed, so that the required processing bandwidth can be lowered.
        use_multi_contexts: When multiple contexts are generated inside the same
            pte, it is possible to reserve a single spill-fill allocation that
            can be re-used across all the splits.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN HTP.
    """
    htp_options = QnnExecuTorchHtpBackendOptions()
    htp_options.precision = (
        QnnExecuTorchHtpPrecision.kHtpFp16
        if use_fp16
        else QnnExecuTorchHtpPrecision.kHtpQuantized
    )
    # This is not actually an option that affects the compiled blob,
    # but we have no other place to pass it for the execution stage.
    # TODO: enable voting mechanism in runtime and make this an option
    htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
    htp_options.use_multi_contexts = use_multi_contexts
    htp_options.use_dlbc = use_dlbc
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
        htp_options=htp_options,
    )

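# Example (illustrative sketch): pick the HTP precision up front; fp16 for a
# floating-point model, quantized when the graph went through PT2E quantization.
#
#     backend_options = generate_htp_compiler_spec(use_fp16=True)   # fp16 graph
#     backend_options = generate_htp_compiler_spec(use_fp16=False)  # quantized graph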

def generate_qnn_executorch_compiler_spec(
    soc_model: QcomChipset,
    backend_options: QnnExecuTorchBackendOptions,
    debug: bool = False,
    saver: bool = False,
    online_prepare: bool = False,
    dump_intermediate_outputs: bool = False,
    profile: bool = False,
    optrace: bool = False,
    shared_buffer: bool = False,
    is_from_context_binary: bool = False,
    multiple_graphs: bool = False,
    graph_name: str = "forward",
) -> List[CompileSpec]:
    """
    Helper function generating compiler specs for Qualcomm AI Engine Direct.

    Args:
        soc_model: The SoC you plan to run the compiled model on. Please check
            QcomChipset for supported SoCs:
            SM8450 (Snapdragon 8 Gen 1)
            SM8475 (Snapdragon 8 Gen 1+)
            SM8550 (Snapdragon 8 Gen 2)
            SM8650 (Snapdragon 8 Gen 3)
        backend_options: Options required by different backends.
        debug: Enable verbose logging. Disclaimer: this option is subject to
            change in the near future.
        online_prepare: Compose the QNN graph on device if set to True.
        saver: Instead of compiling the model, run QNN Saver. Please check the
            documentation of the Qualcomm AI Engine Direct SDK. This feature is
            usually for debugging purposes.
        dump_intermediate_outputs: If tensor dump is enabled, all intermediate tensor outputs will be dumped.
            This option exists for debugging accuracy issues.
        profile: Enable per-operator performance profiling.
            Note that for now only kProfileDetailed is supported, which
            profiles the performance of each operator in cycle units.
        shared_buffer: Enable usage of a shared buffer between the application
            and the backend for graph I/O.
        is_from_context_binary: True if the current graph comes from a pre-built context binary.
        multiple_graphs: True if multiple methods are expected in a single .pte file.
            Please see test cases for a post-processing example.
        graph_name: Assign a unique graph name if 'multiple_graphs' is used.

    Returns:
        List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct.

    Raises:
        ValueError: The given QcomChipset is currently not supported.
        ValueError: Conflict between compiler specs.
    """
    _supported_soc_models = {soc_model.value for soc_model in QcomChipset}
    if soc_model not in _supported_soc_models:
        raise ValueError(f"unknown SoC model for QNN: {soc_model}")

    if profile and dump_intermediate_outputs:
        warnings.warn(
            "It is not recommended to turn on both profiling and dump_intermediate_outputs at the same time"
            ", because dump_intermediate_outputs will cause a performance drop.",
            stacklevel=1,
        )

    qnn_executorch_options = QnnExecuTorchOptions(
        _soc_info_table[soc_model], backend_options
    )
    qnn_executorch_options.graph_name = graph_name
    qnn_executorch_options.log_level = (
        QnnExecuTorchLogLevel.kLogLevelDebug
        if debug
        else QnnExecuTorchLogLevel.kLogLevelWarn
    )

    qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs

    if saver:
        qnn_executorch_options.library_path = "libQnnSaver.so"

    if optrace:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace
    elif profile:
        qnn_executorch_options.profile_level = (
            QnnExecuTorchProfileLevel.kProfileDetailed
        )
    else:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff

    if (
        online_prepare
        and backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend
        and backend_options.htp_options.use_multi_contexts
    ):
        raise ValueError(
            "'use_multi_contexts' cannot be used in online prepare mode, "
            "please set 'online_prepare' to False"
        )

    qnn_executorch_options.shared_buffer = shared_buffer
    qnn_executorch_options.online_prepare = online_prepare
    qnn_executorch_options.is_from_context_binary = is_from_context_binary
    qnn_executorch_options.multiple_graphs = multiple_graphs

    if multiple_graphs:
        # enable the weight sharing mechanism if multiple graphs appear
        if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend:
            backend_options.htp_options.use_weight_sharing = True

    return [
        CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options))
    ]

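# Example (illustrative sketch): assemble the full compile spec list for a
# Snapdragon 8 Gen 3 target and hand it to QnnPartitioner, mirroring how other
# helpers in this file consume `compiler_specs`.
#
#     backend_options = generate_htp_compiler_spec(use_fp16=False)
#     compiler_specs = generate_qnn_executorch_compiler_spec(
#         soc_model=QcomChipset.SM8650,
#         backend_options=backend_options,
#     )
#     partitioner = QnnPartitioner(compiler_specs)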

def get_soc_to_arch_map():
    return {
        "SSG2115P": HtpArch.V73,
        "SM8650": HtpArch.V75,
        "SM8550": HtpArch.V73,
        "SM8475": HtpArch.V69,
        "SM8450": HtpArch.V69,
        "SA8295": HtpArch.V68,
    }


def get_soc_to_chipset_map():
    return {
        "SSG2115P": QcomChipset.SSG2115P,
        "SM8650": QcomChipset.SM8650,
        "SM8550": QcomChipset.SM8550,
        "SM8475": QcomChipset.SM8475,
        "SM8450": QcomChipset.SM8450,
        "SA8295": QcomChipset.SA8295,
    }


def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
    """
    Tag IO nodes that consume or produce quantized tensors, so there is no need to insert q/dq ops in qnn_preprocess.
    """
    for node in gm.graph.nodes:
        if dtype := get_quant_io_dtype_fn(node):
            node.meta[QCOM_QUANTIZED_IO] = dtype
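
# Example (illustrative sketch; the placeholder-name prefix is hypothetical):
# a callback that tags matching graph inputs as uint8 quantized I/O and leaves
# every other node untouched.
#
#     def get_quant_io_dtype(node):
#         if node.op == "placeholder" and node.name.startswith("quant_in"):
#             return torch.uint8
#         return None
#
#     tag_quant_io(exported_program.graph_module, get_quant_io_dtype)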
992