xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/torchvision_vit.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12
13import torch
14from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
15from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
16from executorch.examples.qualcomm.utils import (
17    build_executorch_binary,
18    get_imagenet_dataset,
19    make_output_dir,
20    setup_common_args_and_variables,
21    SimpleADB,
22    topk_accuracy,
23)
24
25
def main(args):
    """Compile torchvision ViT for QNN and optionally evaluate it on a device.

    The model is built with ``QuantDtype.use_8a8w`` into
    ``<args.artifact>/vit_qnn_q8.pte``.

    With ``args.compile_only``, a single random ``(1, 3, 224, 224)`` tensor is
    used as the calibration input, and the function returns right after
    compilation.

    Otherwise, up to ``data_num`` samples are loaded from the ImageNet
    validation folder ``args.dataset``. The full evaluation run then:

    * uses the samples both for calibration and as device inputs,
    * pushes them with ``SimpleADB``, executes the model and pulls the raw
      outputs back into ``<args.artifact>/outputs``,
    * reports top-1 / top-5 accuracy, either as JSON over the
      ``(args.ip, args.port)`` connection or on stdout.

    Args:
        args: Parsed command-line namespace produced by the parser built in
            this script's ``__main__`` block.

    Raises:
        ValueError: If an evaluation run (no ``--compile_only``) is requested
            without ``--dataset``.
    """
    # Fail fast and before any side effects: otherwise the literal string
    # "None" would be forwarded as the dataset path and fail obscurely later.
    if not args.compile_only and not args.dataset:
        raise ValueError("--dataset is required unless --compile_only is specified")

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    if args.compile_only:
        inputs = [(torch.rand(1, 3, 224, 224),)]
    else:
        inputs, targets, input_list = get_imagenet_dataset(
            dataset_path=args.dataset,
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    pte_filename = "vit_qnn_q8"
    instance = TorchVisionViTModel()
    build_executorch_binary(
        instance.get_eager_model().eval(),
        instance.get_example_inputs(),
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis: one raw float32 output file per input sample. Iterate
    # over the samples actually loaded rather than the requested data_num;
    # presumably the dataset helper may return fewer (TODO confirm).
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
92
93
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset "
            "(required unless --compile_only is given). "
            "e.g. --dataset imagenet-mini/val "
            "(for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=False,
    )
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./torchvision_vit",
        default="./torchvision_vit",
        type=str,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        # When an ip/port pair was supplied, report the failure message over
        # the connection instead of crashing, mirroring how main() reports
        # its results.
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare re-raise keeps the original exception type and traceback
            # (wrapping in Exception(e) discarded both).
            raise
125