xref: /aosp_15_r20/external/executorch/examples/qualcomm/qaihub_scripts/utils/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import gc
8
9import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
10
11from executorch.backends.qualcomm.utils.utils import (
12    generate_qnn_executorch_option,
13    update_spill_fill_size,
14)
15
16
def preprocess_binary(ctx_bin, compiler_specs):
    """Wrap a raw QNN context binary into the binary-info format expected downstream.

    Args:
        ctx_bin: raw bytes of a QNN context binary.
        compiler_specs: compiler specs used to build the QNN manager options.

    Returns:
        bytes: the wrapped binary-info blob produced by the QNN manager.
    """
    option = generate_qnn_executorch_option(compiler_specs)
    manager = PyQnnManagerAdaptor.QnnManager(option)
    return bytes(manager.MakeBinaryInfo(ctx_bin))
22
23
def get_encoding(
    path_to_shard: str,
    compiler_specs: str,
    get_input: bool,
    get_output: bool,
    num_input: int,
    num_output: int,
):
    """Extract per-tensor quantization encodings from a QNN context binary shard.

    Args:
        path_to_shard: path to the context binary file on disk.
        compiler_specs: compiler specs used to build the QNN manager options.
        get_input: if True, collect encodings for the first ``num_input`` graph inputs.
        get_output: if True, collect encodings for the first ``num_output`` graph outputs.
        num_input: number of input tensors to read encodings from.
        num_output: number of output tensors to read encodings from.

    Returns:
        list: up to two dicts of the form ``{"scale": [...], "offset": [...]}`` —
        inputs first (if requested), then outputs (if requested).
    """

    def _collect_encodings(tensors, count):
        # Shared extraction for inputs/outputs; previously duplicated inline.
        result = {"scale": [], "offset": []}
        for i in range(count):
            encoding = tensors[i].GetEncodings()
            result["scale"].append(encoding.data["scale"].item())
            result["offset"].append(encoding.data["offset"].item())
        return result

    with open(path_to_shard, "rb") as f:
        ctx_bin = preprocess_binary(f.read(), compiler_specs)

    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
        generate_qnn_executorch_option(compiler_specs), ctx_bin
    )
    assert qnn_mgr.Init().value == 0, "failed to load context binary"
    graph_name = qnn_mgr.GetGraphNames()[0]
    qnn_mgr.AllocateTensor(graph_name)

    encoding_list = []
    try:
        if get_input:
            # Hoist the loop-invariant GetGraphInputs call out of the per-tensor loop.
            encoding_list.append(
                _collect_encodings(qnn_mgr.GetGraphInputs(graph_name), num_input)
            )
        if get_output:
            encoding_list.append(
                _collect_encodings(qnn_mgr.GetGraphOutputs(graph_name), num_output)
            )
    finally:
        # Always release the native QNN manager, even if encoding extraction raises;
        # the original leaked it on any exception after Init().
        qnn_mgr.Destroy()
    return encoding_list
59
60
def gen_pte_from_ctx_bin(artifact, pte_names, bundle_programs, backend_config):
    """Serialize each bundled edge program into a .pte file under *artifact*.

    Args:
        artifact: output directory for the generated .pte files.
        pte_names: one name per program; consumed in order.
        bundle_programs: dicts each holding an "edge_program_manager"; NOTE this
            list is drained front-to-back as a deliberate memory-release strategy.
        backend_config: config forwarded to ``to_executorch``.

    Returns:
        list: paths of the .pte files written, in generation order.
    """
    managers = [bundle["edge_program_manager"] for bundle in bundle_programs]
    # Setup spill-fill buffer for relieving runtime memory usage
    update_spill_fill_size(
        [mgr._edge_programs[next(iter(mgr.methods))] for mgr in managers]
    )
    # Export pte files
    pte_files = []
    for pte_name in pte_names:
        print(f"{pte_name} generating...")
        output_path = f"{artifact}/{pte_name}.pte"
        pte_files.append(output_path)
        with open(output_path, "wb") as f:
            managers[0].to_executorch(config=backend_config).write_to_file(f)
        # GC for reducing host memory consuming: drop the just-written program's
        # references (including the caller's bundle entry) before the next export.
        bundle_programs.pop(0)
        managers.pop(0)
        gc.collect()

    return pte_files
83