# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import json
import os
from multiprocessing.connection import Client

import numpy as np
import timm
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    get_imagenet_dataset,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
    topk_accuracy,
)


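# Export timm's gMLP-S/16 (gmlp_s16_224) image classifier to a QNN-delegated
# ExecuTorch program (.pte), optionally run it on a connected Qualcomm device,
# and report top-1/top-5 accuracy on a sampled ImageNet validation set.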
def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "Device serial is required when not compiling only. "
            "Please specify a device serial with the -s/--device argument."
        )

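    # Sample 100 images from the validation set; they serve both as the
    # calibration data for quantization and as the accuracy-evaluation set.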
    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(224, 224),
    )

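    # Instantiate the pretrained gMLP model from timm in eval mode; it expects
    # a single 1 x 3 x 224 x 224 input tensor.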
    pte_filename = "gMLP_image_classification_qnn"
    model = timm.create_model("gmlp_s16_224", pretrained=True).eval()
    sample_input = (torch.randn(1, 3, 224, 224),)

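    # Quantize to 8-bit activations/weights, lower to the Qualcomm (QNN)
    # backend, and write {args.artifact}/{pte_filename}.pte; `inputs` is used
    # as the calibration dataset.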
    build_executorch_binary(
        model,
        sample_input,
        args.model,
        f"{args.artifact}/{pte_filename}",
        dataset=inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

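    # Stop after compilation when only the .pte artifact is requested.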
    if args.compile_only:
        return

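    # Stage the compiled program and input tensors on the device over adb and
    # execute them there; QNN_SDK_ROOT must point at the installed QNN SDK.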
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

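    # Each on-device invocation writes its logits to output_{i}_0.raw as a
    # flat float32 buffer of 1000 class scores.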
    # top-k analysis
    predictions = []
    for i in range(data_num):
        prediction = np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        ).reshape([1, 1000])
        for j in range(prediction.shape[0]):
            predictions.append(prediction[j])

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
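    # Report accuracy to a listening host process when --ip/--port are given,
    # otherwise print it to stdout.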
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
    else:
        for i, k in enumerate(k_val):
            print(f"top_{k}->{topk[i]}%")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default: ./gMLP_image_classification",
        default="./gMLP_image_classification",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of the ImageNet dataset, "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
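    # Forward any failure to the listening host process when --ip/--port are
    # given; otherwise re-raise it.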
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # re-raise with the original exception type and traceback
            raise
136