xref: /aosp_15_r20/external/executorch/examples/qualcomm/oss_scripts/dino_v2.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12import torch
13from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
14
15from executorch.examples.qualcomm.utils import (
16    build_executorch_binary,
17    get_imagenet_dataset,
18    make_output_dir,
19    parse_skip_delegation_node,
20    setup_common_args_and_variables,
21    SimpleADB,
22    topk_accuracy,
23)
24
25from transformers import Dinov2ForImageClassification
26
27
def get_instance():
    """Return the pretrained DINOv2 image classifier switched to eval mode.

    The weights are loaded through
    ``Dinov2ForImageClassification.from_pretrained`` using the
    ``facebook/dinov2-small-imagenet1k-1-layer`` checkpoint identifier.
    """
    checkpoint = "facebook/dinov2-small-imagenet1k-1-layer"
    classifier = Dinov2ForImageClassification.from_pretrained(checkpoint)
    return classifier.eval()
34
35
def main(args):
    """Build the DINOv2 program and, unless compile-only, evaluate it on device.

    Steps, in order:
      1. Parse the skip-delegation options and create the artifact directory.
      2. Fail early when on-device execution is requested without a device
         serial.
      3. Load calibration/evaluation samples via ``get_imagenet_dataset`` and
         build ``<artifact>/dino_v2`` with ``build_executorch_binary`` using
         ``QuantDtype.use_8a8w``.
      4. If not compile-only: push inputs with ``SimpleADB``, execute, pull
         the outputs back and compute top-1 / top-5 accuracy.
      5. Report the accuracy either over a ``Client`` connection (when
         ``args.ip`` is set and ``args.port`` is not -1) or on stdout.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``artifact``, ``compile_only``, ``device``, ``dataset``, ``model``,
            ``shared_buffer``, ``build_folder``, ``host``, ``ip`` and ``port``.

    Raises:
        RuntimeError: If ``args.compile_only`` is falsy and ``args.device``
            is ``None``.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists before any artifact is written.
    os.makedirs(args.artifact, exist_ok=True)

    runs_on_device = not args.compile_only
    if runs_on_device and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    crop, sample_count = 224, 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=sample_count,
        image_shape=(256, 256),
        crop_size=crop,
    )
    # Random tensor with the model's input shape, handed to the build step
    # as the example input.
    sample_input = (torch.randn((1, 3, crop, crop)),)

    pte_filename = "dino_v2"
    model = get_instance()
    build_executorch_binary(
        model,
        sample_input,
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if not runs_on_device:
        return

    runner = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    runner.push(inputs=inputs, input_list=input_list)
    runner.execute()

    # Prepare the local folder that will receive the device outputs.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    runner.pull(output_path=args.artifact)

    # Top-k analysis: each sample's output is read back as raw float32.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{idx}_0.raw"),
            dtype=np.float32,
        )
        for idx in range(sample_count)
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]

    report_remotely = args.ip and args.port != -1
    if report_remotely:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
109
110
if __name__ == "__main__":
    # Script entry point: extend the parser returned by
    # setup_common_args_and_variables() with the options specific to this
    # example, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing generated artifacts by this example. Default ./dino_v2",
        default="./dino_v2",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener was configured: report the failure to it as JSON
            # instead of letting the exception propagate.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception unchanged. Wrapping it in a new
            # generic Exception would discard its type and add a redundant
            # chained traceback.
            raise
143