xref: /aosp_15_r20/external/executorch/examples/qualcomm/oss_scripts/fbnet.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9import re
10from multiprocessing.connection import Client
11
12import numpy as np
13import timm
14from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
15from executorch.examples.qualcomm.utils import (
16    build_executorch_binary,
17    get_imagenet_dataset,
18    make_output_dir,
19    setup_common_args_and_variables,
20    SimpleADB,
21    topk_accuracy,
22)
23
24
def main(args):
    """Build the FBNet-C example for a Qualcomm SoC and optionally evaluate it on device.

    Steps:

    1. Load timm's pretrained ``fbnetc_100``.
    2. Read ``data_num`` images from the dataset folder.
    3. Lower the model with ``build_executorch_binary`` into
       ``<args.artifact>/fbnet.pte``, using ``QuantDtype.use_8a8w``.
       The image batch is passed along too; presumably it is used as
       calibration data (confirm in ``build_executorch_binary``).

    Unless ``args.compile_only`` is set, it then:

    4. Pushes the program and inputs to the device over ADB and runs them.
    5. Pulls the outputs back.
    6. Reports top-1/top-5 accuracy, either as JSON to the
       ``(args.ip, args.port)`` listener or on stdout.

    Args:
        args: Parsed command-line namespace. Reads ``compile_only``,
            ``device``, ``artifact``, ``dataset``, ``model``,
            ``shared_buffer``, ``build_folder``, ``host``, ``ip`` and ``port``.

    Raises:
        RuntimeError: If a device run is requested without a device serial.
    """
    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    instance = timm.create_model("fbnetc_100", pretrained=True).eval()

    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(299, 299),
    )

    pte_filename = "fbnet"

    build_executorch_binary(
        instance,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # Collect output data.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    # Only output tensor 0 of each sample ("output_<i>_0.raw") is read below.
    # Any other output tensor ("output_<i>_<k>.raw", k >= 1, any number of
    # digits) is deleted. The "." is escaped so it matches a literal dot.
    extra_output_pattern = re.compile(r"^output_[0-9]+_[1-9][0-9]*\.raw$")

    def post_process():
        # Prune-only pass. Filenames are not parsed or sorted, so unrelated
        # files in the folder cannot make this callback crash.
        for name in os.listdir(output_data_folder):
            if extra_output_pattern.match(name):
                os.remove(os.path.join(output_data_folder, name))

    adb.pull(output_path=args.artifact, callback=post_process)

    # Top-k analysis: one float32 prediction vector per sample, in sample order.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(data_num)
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
107
108
109if __name__ == "__main__":
110    parser = setup_common_args_and_variables()
111    parser.add_argument(
112        "-a",
113        "--artifact",
114        help="path for storing generated artifacts by this example. Default ./fbnet",
115        default="./fbnet",
116        type=str,
117    )
118
119    parser.add_argument(
120        "-d",
121        "--dataset",
122        help=(
123            "path to the validation folder of ImageNet dataset. "
124            "e.g. --dataset imagenet-mini/val "
125            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
126        ),
127        type=str,
128        required=True,
129    )
130
131    args = parser.parse_args()
132    try:
133        main(args)
134    except Exception as e:
135        if args.ip and args.port != -1:
136            with Client((args.ip, args.port)) as conn:
137                conn.send(json.dumps({"Error": str(e)}))
138        else:
139            raise Exception(e)
140