xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/inception_v4.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12
13import torch
14from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
15from executorch.examples.models.inception_v4 import InceptionV4Model
16from executorch.examples.qualcomm.utils import (
17    build_executorch_binary,
18    get_imagenet_dataset,
19    make_output_dir,
20    parse_skip_delegation_node,
21    setup_common_args_and_variables,
22    SimpleADB,
23    topk_accuracy,
24)
25
26
def main(args):
    """Compile Inception-V4 for a Qualcomm SoC and optionally evaluate it on device.

    The eager model is lowered with ``build_executorch_binary`` using
    ``QuantDtype.use_8a8w`` and written to ``<args.artifact>/ic4_qnn_q8.pte``.
    If ``args.compile_only`` is set, the function stops there. Otherwise it
    does the following:

    * pushes the program and up to ``data_num`` ImageNet samples from
      ``args.dataset`` to the device ``args.device`` via ``SimpleADB``;
    * runs them and pulls the raw outputs back into ``<args.artifact>/outputs``;
    * reports top-1 / top-5 accuracy. The accuracy is sent as JSON to
      ``(args.ip, args.port)`` when a listener is configured, and printed
      otherwise.

    Args:
        args: Parsed command-line namespace from
            ``setup_common_args_and_variables`` extended with this script's
            ``--dataset`` and ``--artifact`` options.

    Raises:
        RuntimeError: If on-device evaluation is requested (not
            ``compile_only``) but no device serial or no dataset path was
            given.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only:
        if args.device is None:
            raise RuntimeError(
                "device serial is required if not compile only. "
                "Please specify a device serial by -s/--device argument."
            )
        # --dataset is optional on the command line, but evaluation cannot run
        # without it. Fail early with a clear message instead of handing the
        # data loader the string "None" as a path.
        if args.dataset is None:
            raise RuntimeError(
                "dataset path is required if not compile only. "
                "Please specify a dataset path by -d/--dataset argument."
            )

    data_num = 100
    image_shape = (299, 299)
    if args.compile_only:
        # Placeholder input of the expected (1, 3, H, W) shape; no real data
        # is needed when nothing is executed on device.
        inputs = [(torch.rand(1, 3, *image_shape),)]
    else:
        inputs, targets, input_list = get_imagenet_dataset(
            dataset_path=args.dataset,
            data_size=data_num,
            image_shape=image_shape,
        )

    pte_filename = "ic4_qnn_q8"
    instance = InceptionV4Model()
    build_executorch_binary(
        instance.get_eager_model().eval(),
        instance.get_example_inputs(),
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # Collect output data.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    # Top-k analysis: one output_<i>_0.raw file is read per pushed input.
    # Size by len(inputs) rather than data_num so the loop matches the samples
    # actually loaded. get_imagenet_dataset is only asked for *at most*
    # data_num samples; presumably a smaller dataset yields fewer (TODO
    # confirm). When the two agree, the behavior is identical.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
101
102
if __name__ == "__main__":
    # Script entry point: extend the shared Qualcomm example CLI with this
    # example's options, then run main() with error reporting.
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=False,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./inception_v4",
        default="./inception_v4",
        type=str,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener is configured: send the error message to it as JSON
            # instead of raising.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare `raise` re-raises the original exception with its own type
            # and traceback intact. Wrapping it in Exception(e) would lose the
            # type and add a misleading second traceback.
            raise
135            raise Exception(e)
136