xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/mobilenet_v3.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12
13import torch
14from executorch.examples.models.mobilenet_v3 import MV3Model
15from executorch.examples.qualcomm.utils import (
16    build_executorch_binary,
17    get_imagenet_dataset,
18    make_output_dir,
19    parse_skip_delegation_node,
20    setup_common_args_and_variables,
21    SimpleADB,
22    topk_accuracy,
23)
24
25
def main(args):
    """Lower MobileNetV3 for a Qualcomm target and optionally evaluate it on device.

    The eager ``MV3Model`` is lowered with ``build_executorch_binary`` into
    ``<artifact>/mv3_qnn_float16.pte``. In compile-only mode the function
    returns right after that. Otherwise the program is pushed to the device
    with ``SimpleADB`` and executed over up to ``data_num`` ImageNet
    validation samples. The raw outputs are then pulled back and scored with
    top-1 / top-5 accuracy. Scores are sent as JSON through a
    ``multiprocessing.connection.Client`` when ``args.ip`` is set and
    ``args.port != -1``; otherwise they are printed.

    Args:
        args: Parsed command-line namespace (common Qualcomm example options
            plus ``--dataset`` and ``--artifact``).

    Raises:
        RuntimeError: If on-device evaluation is requested (not compile-only)
            without a device serial or without a dataset path.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )
    # --dataset is optional on the command line, but evaluation cannot run
    # without it; fail early instead of passing the string "None" downstream.
    if not args.compile_only and args.dataset is None:
        raise RuntimeError(
            "dataset path is required if not compile only. "
            "Please specify the ImageNet validation folder by -d/--dataset argument."
        )

    # Upper bound on the number of validation samples requested.
    data_num = 100
    if args.compile_only:
        # A single random placeholder input is enough when only compiling.
        inputs = [(torch.rand(1, 3, 224, 224),)]
    else:
        inputs, targets, input_list = get_imagenet_dataset(
            dataset_path=args.dataset,
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )
    pte_filename = "mv3_qnn_float16"
    instance = MV3Model()
    build_executorch_binary(
        instance.get_eager_model().eval(),
        instance.get_example_inputs(),
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis. Read one output per sample that was actually loaded and
    # pushed, rather than assuming data_num samples were available: the
    # loader may return fewer, and no output file exists for missing samples.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
100
101
if __name__ == "__main__":
    # Script entry point: extend the shared Qualcomm example CLI with the
    # dataset/artifact options, run main(), and route failures either to the
    # (ip, port) listener or back to the caller.
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        # Optional on the CLI; main() requires it unless compiling only.
        required=False,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./mobilenet_v3",
        default="./mobilenet_v3",
        type=str,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the listening host as JSON instead of
            # crashing the process.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare re-raise preserves the original exception type and
            # traceback (wrapping in Exception(e) discarded both).
            raise
135