# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


def get_instance():
    import torchvision
    from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights

    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
    )

    # the post-processing in the vanilla forward method cannot be exported;
    # keep only the backbone and detection head so torch.export.export succeeds
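    # note: the detection head returns a dict of two tensors: "cls_logits" of
    # shape (N, A, 91) and "bbox_regression" of shape (N, A, 4), where A is the
    # total anchor count across FPN levels; the example assumes the exported
    # outputs keep this order (cls_logits first)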
    def forward_without_metrics(self, image):
        features = self.backbone(image)
        return self.head(list(features.values()))

    model.forward = lambda img: forward_without_metrics(model, img)
    return model.eval()


def get_dataset(data_size, dataset_dir):
    from torchvision import datasets, transforms

    class COCODataset(datasets.CocoDetection):
        def __init__(self, dataset_root):
            self.images_path = os.path.join(dataset_root, "val2017")
            self.annots_path = os.path.join(
                dataset_root, "annotations/instances_val2017.json"
            )
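            # note: this 640x640 shape must stay in sync with the hard-coded
            # feature_size list in eval_metric below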
            self.img_shape = (640, 640)
            self.preprocess = transforms.Compose(
                [
                    transforms.PILToTensor(),
                    transforms.ConvertImageDtype(torch.float),
                    transforms.Resize(self.img_shape),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            with open(self.annots_path, "r") as f:
                data = json.load(f)
                categories = data["categories"]
                self.label_names = {
                    category["id"]: category["name"] for category in categories
                }

            super().__init__(root=self.images_path, annFile=self.annots_path)

        def __getitem__(self, index):
            img, target = super().__getitem__(index)

            bboxes, labels = [], []
            for obj in target:
                bboxes.append(self.resize_bbox(obj["bbox"], img.size))
                labels.append(obj["category_id"])

            # return empty list if no label exists
            return (
                self.preprocess(img),
                torch.stack(bboxes) if len(bboxes) > 0 else [],
                torch.Tensor(labels).to(torch.int) if len(labels) > 0 else [],
            )

        def resize_bbox(self, bbox, orig_shape):
            # bypass if no label exists
            if len(bbox) == 0:
                return

            # orig_shape comes from PIL's img.size, i.e. (width, height);
            # self.img_shape is (height, width)
            x_scale = float(self.img_shape[1]) / orig_shape[0]
            y_scale = float(self.img_shape[0]) / orig_shape[1]
            # bbox: [(upper-left) x, y, w, h] -> [x1, y1, x2, y2]
            bbox[2] += bbox[0]
            bbox[3] += bbox[1]
            # rescale bbox according to the resized image shape
            bbox[0] = x_scale * bbox[0]
            bbox[2] = x_scale * bbox[2]
            bbox[1] = y_scale * bbox[1]
            bbox[3] = y_scale * bbox[3]
            return torch.Tensor(bbox)

    dataset = COCODataset(dataset_root=dataset_dir)
    test_loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True)
    inputs, input_list = [], ""
    bboxes, targets = [], []
    for index, (img, boxes, labels) in enumerate(test_loader):
        if index >= data_size:
            break
        inputs.append((img,))
        input_list += f"input_{index}_0.raw\n"
        bboxes.append(boxes)
        targets.append(labels)

    return inputs, input_list, bboxes, targets, dataset.label_names


def calculate_precision(
    true_boxes, true_labels, det_boxes, det_labels, tp, fp, top_k, iou_thres
):
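    # for every class present in the ground truth, each of the top_k detections
    # of that class counts as a true positive if it overlaps some ground-truth
    # box of the same class with IoU >= iou_thres, otherwise as a false positive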
    import torchvision

    def collect_data(boxes, labels, top_k=-1):
        # keep at most the first top_k entries
        top_k = labels.size(0) if top_k == -1 else top_k
        len_labels = min(labels.size(0), top_k)
        boxes, labels = boxes[:len_labels, :], labels[:len_labels]
        # classes present in the current data
        cls = set(labels[:len_labels].tolist())
        grouped = {index: [] for index in cls}
        # group boxes by class
        for j in range(len_labels):
            index = labels[j].item()
            if index in cls:
                grouped[index].append(boxes[j, :])
        return {k: torch.stack(v) for k, v in grouped.items()}

    preds = collect_data(det_boxes, det_labels, top_k=top_k)
    targets = collect_data(true_boxes.squeeze(0), true_labels.squeeze(0))
    # evaluate only the classes present in the ground-truth data
    for index in targets.keys():
        # predictions whose class is absent from the ground truth add no precision
        if index in preds:
            # targets shape: (M, 4), preds shape: (N, 4)
            # shape after box_iou: (M, N); after max over dim 0, iou shape: (N)
            # true-positive: predictions that meet the IoU threshold, i.e. k of N
            # false-positive: N - true-positive = N - k
            iou, _ = torchvision.ops.box_iou(targets[index], preds[index]).max(0)
            tps = torch.where(iou >= iou_thres, 1, 0).sum().item()
            tp[index - 1] += tps
            fp[index - 1] += iou.nelement() - tps


def eval_metric(instance, heads, images, bboxes, targets, classes):
    tp, fp = classes * [0], classes * [0]
    head_label = ["cls_logits", "bbox_regression"]

    # the feature sizes below must be updated if the input size changes
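    # with 640x640 inputs the P3-P7 FPN levels (strides 8/16/32/64/128)
    # produce feature maps of 80/40/20/10/5 on each side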
    feature_size = [80, 40, 20, 10, 5]
    feature_maps = [torch.zeros(1, 256, h, h) for h in feature_size]
    for head, image, true_boxes, true_labels in zip(heads, images, bboxes, targets):
        anchors = instance.anchor_generator(
            image_list=image,
            feature_maps=feature_maps,
        )
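        # the default anchor generator places 9 anchors per feature-map location
        # (3 scales x 3 aspect ratios), hence hw * hw * 9 anchors per level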
        num_anchors_per_level = [hw**2 * 9 for hw in feature_size]
        # split outputs per level
        split_head_outputs = {
            head_label[i]: list(h.split(num_anchors_per_level, dim=1))
            for i, h in enumerate(head)
        }
        split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
        # compute the detections (based on the official post-processing method)
        detection = instance.postprocess_detections(
            head_outputs=split_head_outputs,
            anchors=split_anchors,
            image_shapes=[image.image_sizes],
        )
        # no contribution to precision if the image has no ground-truth labels
        if len(true_labels) == 0:
            continue
        # use the top-10 most confident detections and IoU >= 0.5 as the criteria
        calculate_precision(
            true_boxes=true_boxes,
            true_labels=true_labels,
            det_boxes=detection[0]["boxes"],
            det_labels=detection[0]["labels"],
            tp=tp,
            fp=fp,
            top_k=10,
            iou_thres=0.5,
        )

    # classes that never appear in the ground truth are marked with -1 and
    # excluded from the mean below
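    # note: AP here is the per-class precision tp / (tp + fp) at a single IoU
    # threshold, not the area under a precision-recall curve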
    AP = torch.Tensor(
        [
            tp[i] * 1.0 / (tp[i] + fp[i]) if tp[i] + fp[i] > 0 else -1
            for i in range(len(tp))
        ]
    )
    missed_labels = torch.where(AP == -1, 1, 0).sum()
    mAP = AP.where(AP != -1, 0).sum() / (AP.nelement() - missed_labels)
    return AP, mAP.item()


def main(args):
    from pprint import PrettyPrinter

    from torchvision.models.detection.image_list import ImageList

    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial with the -s/--device argument."
        )

    model = get_instance()

    # retrieve dataset
    data_num = 100
    # the COCO annotations define 91 category ids
    n_classes, n_coord_of_bbox = 91, 4
    inputs, input_list, bboxes, targets, label_names = get_dataset(
        data_size=data_num, dataset_dir=args.dataset
    )
    pte_filename = "retinanet_qnn"
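    # lower the model to a QNN-delegated .pte with 8-bit activations and
    # weights (8a8w quantization)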
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

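    # each sample yields two flattened float32 files on device: output_{i}_0.raw
    # (cls_logits) and output_{i}_1.raw (bbox_regression); reshape them back to
    # (1, num_anchors, dim) below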
    predictions, classes = [], [n_classes, n_coord_of_bbox]
    for i in range(data_num):
        result = []
        for j, dim in enumerate(classes):
            data_np = np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_{j}.raw"),
                dtype=np.float32,
            )
            result.append(torch.from_numpy(data_np).reshape(1, -1, dim))
        predictions.append(result)

    # evaluate metrics
    AP, mAP = eval_metric(
        instance=model,
        heads=predictions,
        images=[ImageList(img[0], tuple(img[0].shape[-2:])) for img in inputs],
        bboxes=bboxes,
        targets=targets,
        classes=n_classes,
    )

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"mAP": mAP}))
    else:
        print("\nMean Average Precision (mAP): %.3f" % mAP)
        print("\nAverage Precision of Classes (AP):")
        PrettyPrinter().pprint(
            {label_names[i + 1]: AP[i].item() for i in range(n_classes) if AP[i] != -1}
        )

if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. "
        "Default ./retinanet",
        default="./retinanet",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of COCO2017 dataset. "
            "e.g. --dataset PATH/TO/COCO (which contains 'val2017' & 'annotations'), "
            "the dataset can be downloaded via http://images.cocodataset.org/zips/val2017.zip & "
            "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise