# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client
from pprint import PrettyPrinter

import numpy as np
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


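# Note: `utils`, `datasets`, and `model` imported below are modules from the
# OSS repo (a-PyTorch-Tutorial-to-Object-Detection) passed via --oss_repo;
# main() puts that repo on sys.path before these lazy imports run.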
def create_data_lists(voc07_path, data_size):
    """
    Create lists of test images and of the bounding boxes and labels of the
    objects in those images, and save them to JSON files.

    :param voc07_path: path to the 'VOC2007' folder
    :param data_size: maximum number of test images to collect
    """
    from utils import parse_annotation

    voc07_path = os.path.abspath(voc07_path)

    # Test data
    test_images = []
    test_objects = []
    n_objects = 0

    # Find IDs of images in the test data
    with open(os.path.join(voc07_path, "ImageSets/Main/test.txt")) as f:
        ids = f.read().splitlines()

    for index, image_id in enumerate(ids):
        if index >= data_size:
            break
        # Parse the annotation's XML file; skip images without any objects
        objects = parse_annotation(
            os.path.join(voc07_path, "Annotations", image_id + ".xml")
        )
        if len(objects) == 0:
            continue
        test_objects.append(objects)
        n_objects += len(objects)
        test_images.append(os.path.join(voc07_path, "JPEGImages", image_id + ".jpg"))

    assert len(test_objects) == len(test_images)

    # TEST_images.json stores the file names of the images, and TEST_objects.json
    # stores info such as boxes, labels, and difficulties
    with open(os.path.join(voc07_path, "TEST_images.json"), "w") as j:
        json.dump(test_images, j)
    with open(os.path.join(voc07_path, "TEST_objects.json"), "w") as j:
        json.dump(test_objects, j)

    print(
        "\nThere are %d test images containing a total of %d objects. Files have been saved to %s."
        % (len(test_images), n_objects, os.path.abspath(voc07_path))
    )


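# torchvision's VOCSegmentation downloader is used below only to fetch the
# VOC2007 test archive; the extracted devkit also contains the JPEGImages and
# Annotations folders that create_data_lists() reads.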
def get_dataset(data_size, dataset_dir, download):
    from datasets import PascalVOCDataset
    from torchvision import datasets

    if download:
        datasets.VOCSegmentation(
            root=os.path.join(dataset_dir, "voc_image"),
            year="2007",
            image_set="test",
            download=True,
        )
    voc07_path = os.path.join(dataset_dir, "voc_image", "VOCdevkit", "VOC2007")
    create_data_lists(voc07_path, data_size)

    # voc07_path is where the data and the ground-truth JSON files are stored
    test_dataset = PascalVOCDataset(voc07_path, split="test", keep_difficult=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, shuffle=True, collate_fn=test_dataset.collate_fn
    )

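    # DataLoader defaults to batch_size=1, so each iteration yields a single
    # image. The input_{index}_0.raw names below follow the file-naming
    # convention the example runner uses to feed inputs on device.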
    inputs, input_list = [], ""
    true_boxes = []
    true_labels = []
    true_difficulties = []
    for index, (images, boxes, labels, difficulties) in enumerate(test_loader):
        if index >= data_size:
            break
        inputs.append((images,))
        input_list += f"input_{index}_0.raw\n"
        true_boxes.extend(boxes)
        true_labels.extend(labels)
        true_difficulties.extend(difficulties)

    return inputs, input_list, true_boxes, true_labels, true_difficulties


def SSD300VGG16(pretrained_weight_model):
    """Build the OSS repo's SSD300 model and load the pretrained checkpoint."""
    from model import SSD300

    model = SSD300(n_classes=21)
    # TODO: If possible, it's better to set weights_only to True
    # https://pytorch.org/docs/stable/generated/torch.load.html
    checkpoint = torch.load(
        pretrained_weight_model, map_location="cpu", weights_only=False
    )
    model.load_state_dict(checkpoint["model"].state_dict())

    return model.eval()


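# Hypothetical local sanity check (not part of this script's flow): the FP32
# model maps one 300x300 RGB image to 8732 prior-box predictions, matching the
# shapes post_process() expects below.
#
#   model = SSD300VGG16("./checkpoint_ssd300.pth.tar")
#   locs, scores = model(torch.randn(1, 3, 300, 300))
#   assert locs.shape == (1, 8732, 4) and scores.shape == (1, 8732, 21)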
def main(args):
    sys.path.insert(0, args.oss_repo)

    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "A device serial is required unless compiling only. "
            "Please specify one with the -s/--device argument."
        )

    data_num = 100
    inputs, input_list, true_boxes, true_labels, true_difficulties = get_dataset(
        data_size=data_num, dataset_dir=args.artifact, download=args.download
    )

    pte_filename = "ssd300_vgg16_qnn"
    model = SSD300VGG16(args.pretrained_weight)

    sample_input = (torch.randn((1, 3, 300, 300)),)
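    # SSD300 takes a fixed 1x3x300x300 input; this random tensor only pins the
    # example input shape for export and compilation.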
    build_executorch_binary(
        model,
        sample_input,
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )
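    # QuantDtype.use_8a8w requests 8-bit activations and 8-bit weights; the
    # dataset `inputs` gathered above serves as the calibration data.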

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()
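    # The on-device run writes one output_{index}_{output}.raw file per model
    # output; they are pulled back into {artifact}/outputs below.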

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    det_boxes = []
    det_labels = []
    det_scores = []

    def post_process():
        from utils import calculate_mAP

        np.set_printoptions(threshold=np.inf)

        # output_xxx_0.raw holds the predicted boxes, and output_xxx_1.raw
        # holds the predicted class scores
        for file_index in range(data_num):
            boxes_filename = os.path.join(
                output_data_folder, f"output_{file_index}_0.raw"
            )
            category_filename = os.path.join(
                output_data_folder, f"output_{file_index}_1.raw"
            )

            predicted_locs = np.fromfile(boxes_filename, dtype=np.float32).reshape(
                [1, 8732, 4]
            )
            predicted_locs = torch.tensor(predicted_locs)

            predicted_scores = np.fromfile(category_filename, dtype=np.float32).reshape(
                [1, 8732, 21]
            )
            predicted_scores = torch.tensor(predicted_scores)
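            # 8732 is SSD300's fixed number of prior boxes; 21 covers the 20
            # Pascal VOC classes plus background.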

            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
                predicted_locs,
                predicted_scores,
                min_score=0.01,
                max_overlap=0.45,
                top_k=200,
            )
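            # detect_objects (from the OSS repo's SSD300) decodes the raw
            # predictions and runs NMS: min_score filters low-confidence boxes,
            # max_overlap is the NMS IoU threshold, and top_k caps detections.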

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)

        pp = PrettyPrinter()
        # Calculate mAP
        APs, mAP = calculate_mAP(
            det_boxes,
            det_labels,
            det_scores,
            true_boxes,
            true_labels,
            true_difficulties,
        )
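        # calculate_mAP (from the OSS repo's utils) scores detections with the
        # VOC metric, where objects flagged as difficult do not count against
        # missed detections.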
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"mAP": float(mAP)}))
        else:
            print("\nMean Average Precision (mAP): %.3f" % mAP)
            pp.pprint(APs)

    adb.pull(output_path=args.artifact, callback=post_process)


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default ./ssd300_vgg16",
        default="./ssd300_vgg16",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--download",
        help="If specified, download the VOC2007 dataset via the torchvision VOCSegmentation API",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--oss_repo",
        help=(
            "Repository that contains the model backbone and score calculation. "
            "e.g., --oss_repo ./a-PyTorch-Tutorial-to-Object-Detection. "
            "Please clone the repository from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of the pretrained model weights. "
            "e.g., -p ./checkpoint_ssd300.pth.tar. "
            "The pretrained model can be found at https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection, under the Training section"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise