# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


def get_instance():
    import torchvision
    from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights

    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
    )

    # the post-processing in the vanilla forward method cannot be exported,
    # so keep only the network structure for torch.export.export to work
    def forward_without_metrics(self, image):
        features = self.backbone(image)
        return self.head(list(features.values()))

    model.forward = lambda img: forward_without_metrics(model, img)
    return model.eval()


def get_dataset(data_size, dataset_dir):
    from torchvision import datasets, transforms

    class COCODataset(datasets.CocoDetection):
        def __init__(self, dataset_root):
            self.images_path = os.path.join(dataset_root, "val2017")
            self.annots_path = os.path.join(
                dataset_root, "annotations/instances_val2017.json"
            )
            self.img_shape = (640, 640)
            self.preprocess = transforms.Compose(
                [
                    transforms.PILToTensor(),
                    transforms.ConvertImageDtype(torch.float),
                    transforms.Resize(self.img_shape),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            with open(self.annots_path, "r") as f:
                data = json.load(f)
                categories = data["categories"]
                self.label_names = {
                    category["id"]: category["name"] for category in categories
                }

            super().__init__(root=self.images_path, annFile=self.annots_path)

        def __getitem__(self, index):
            img, target = super().__getitem__(index)

            bboxes, labels = [], []
            for obj in target:
                bboxes.append(self.resize_bbox(obj["bbox"], img.size))
                labels.append(obj["category_id"])

            # return empty lists if no label exists
            return (
                self.preprocess(img),
                torch.stack(bboxes) if len(bboxes) > 0 else [],
                torch.Tensor(labels).to(torch.int) if len(labels) > 0 else [],
            )

        def resize_bbox(self, bbox, orig_shape):
            # bypass if no label exists
            if len(bbox) == 0:
                return

            # orig_shape comes from PIL's img.size, i.e. (width, height);
            # self.img_shape is (height, width)
            x_scale = float(self.img_shape[1]) / orig_shape[0]
            y_scale = float(self.img_shape[0]) / orig_shape[1]
            # bbox: [(upper-left) x, y, w, h] -> [x1, y1, x2, y2]
            bbox[2] += bbox[0]
            bbox[3] += bbox[1]
            # rescale bbox to the resized image
            bbox[0] = x_scale * bbox[0]
            bbox[2] = x_scale * bbox[2]
            bbox[1] = y_scale * bbox[1]
            bbox[3] = y_scale * bbox[3]
            return torch.Tensor(bbox)

    dataset = COCODataset(dataset_root=dataset_dir)
    test_loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True)
    inputs, input_list = [], ""
    bboxes, targets = [], []
    for index, (img, boxes, labels) in enumerate(test_loader):
        if index >= data_size:
            break
        inputs.append((img,))
        input_list += f"input_{index}_0.raw\n"
        bboxes.append(boxes)
        targets.append(labels)

    return inputs, input_list, bboxes, targets, dataset.label_names

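
# Expected layout of the --dataset directory (a sketch; the file names follow the
# official COCO 2017 release and match the paths built in COCODataset above):
#   PATH/TO/COCO/
#   |-- val2017/                        <- validation images (*.jpg)
#   `-- annotations/
#       `-- instances_val2017.json      <- instance annotations parsed for labels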


def calculate_precision(
    true_boxes, true_labels, det_boxes, det_labels, tp, fp, top_k, iou_thres
):
    import torchvision

    def collect_data(boxes, labels, top_k=-1):
        # extract data up to top_k entries
        top_k = labels.size(0) if top_k == -1 else top_k
        len_labels = min(labels.size(0), top_k)
        boxes, labels = boxes[:len_labels, :], labels[:len_labels]
        # classes present in the current data
        cls = set(labels[:len_labels].tolist())
        boxes_per_class = {index: [] for index in cls}
        # stack boxes of the same class
        for j in range(len_labels):
            index = labels[j].item()
            if index in cls:
                boxes_per_class[index].append(boxes[j, :])
        return {k: torch.stack(v) for k, v in boxes_per_class.items()}

    preds = collect_data(det_boxes, det_labels, top_k=top_k)
    targets = collect_data(true_boxes.squeeze(0), true_labels.squeeze(0))
    # evaluate only classes present in the ground truth data
    for index in targets.keys():
        # predictions of classes absent from the ground truth add no precision
        if index in preds:
            # targets shape: (M, 4), preds shape: (N, 4)
            # box_iou shape: (M, N); max over dim 0 gives iou of shape (N)
            # true-positive: predictions whose best IoU meets the threshold, i.e. k of N
            # false-positive: N - true-positive = N - k
            iou, _ = torchvision.ops.box_iou(targets[index], preds[index]).max(0)
            tps = torch.where(iou >= iou_thres, 1, 0).sum().item()
            tp[index - 1] += tps
            fp[index - 1] += iou.nelement() - tps


def eval_metric(instance, heads, images, bboxes, targets, classes):
    tp, fp = classes * [0], classes * [0]
    head_label = ["cls_logits", "bbox_regression"]

    # feature sizes must be updated if the input size changes
    feature_size = [80, 40, 20, 10, 5]
    feature_maps = [torch.zeros(1, 256, h, h) for h in feature_size]
    for head, image, true_boxes, true_labels in zip(heads, images, bboxes, targets):
        anchors = instance.anchor_generator(
            image_list=image,
            feature_maps=feature_maps,
        )
        num_anchors_per_level = [hw**2 * 9 for hw in feature_size]
        # split outputs per level
        split_head_outputs = {
            head_label[i]: list(h.split(num_anchors_per_level, dim=1))
            for i, h in enumerate(head)
        }
        split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
        # compute the detections (based on the official post-processing method)
        detection = instance.postprocess_detections(
            head_outputs=split_head_outputs,
            anchors=split_anchors,
            image_shapes=[image.image_sizes],
        )
        # no contribution to precision
        if len(true_labels) == 0:
            continue
        # criteria: top-10 predictions by confidence and IoU >= 0.5
        calculate_precision(
            true_boxes=true_boxes,
            true_labels=true_labels,
            det_boxes=detection[0]["boxes"],
            det_labels=detection[0]["labels"],
            tp=tp,
            fp=fp,
            top_k=10,
            iou_thres=0.5,
        )

    # labels that never appear in the current dataset get AP = -1
    # and are excluded from mAP
    AP = torch.Tensor(
        [
            tp[i] * 1.0 / (tp[i] + fp[i]) if tp[i] + fp[i] > 0 else -1
            for i in range(len(tp))
        ]
    )
    missed_labels = torch.where(AP == -1, 1, 0).sum()
    mAP = AP.where(AP != -1, 0).sum() / (AP.nelement() - missed_labels)
    return AP, mAP.item()

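
# Worked example of the precision bookkeeping above (illustrative numbers only):
# a class with tp = 7 and fp = 3 gets AP = 7 / (7 + 3) = 0.70, a class that never
# appears in the ground truth keeps AP = -1 and is skipped, and mAP is the mean
# of the remaining AP values.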


def main(args):
    from pprint import PrettyPrinter

    from torchvision.models.detection.image_list import ImageList

    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    model = get_instance()

    # retrieve dataset
    data_num = 100
    # the COCO label map defines 91 category ids
    n_classes, n_coord_of_bbox = 91, 4
    inputs, input_list, bboxes, targets, label_names = get_dataset(
        data_size=data_num, dataset_dir=args.dataset
    )
    pte_filename = "retinanet_qnn"
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    predictions, classes = [], [n_classes, n_coord_of_bbox]
    for i in range(data_num):
        result = []
        for j, dim in enumerate(classes):
            data_np = np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_{j}.raw"),
                dtype=np.float32,
            )
            result.append(torch.from_numpy(data_np).reshape(1, -1, dim))
        predictions.append(result)

    # evaluate metrics
    AP, mAP = eval_metric(
        instance=model,
        heads=predictions,
        images=[ImageList(img[0], tuple(img[0].shape[-2:])) for img in inputs],
        bboxes=bboxes,
        targets=targets,
        classes=n_classes,
    )

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"mAP": mAP}))
    else:
        print("\nMean Average Precision (mAP): %.3f" % mAP)
        print("\nAverage Precision of Classes (AP):")
        PrettyPrinter().pprint(
            {label_names[i + 1]: AP[i].item() for i in range(n_classes) if AP[i] != -1}
        )


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./retinanet",
        default="./retinanet",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of the COCO2017 dataset, "
            "e.g. --dataset PATH/TO/COCO (which contains 'val2017' & 'annotations'); "
            "the dataset can be downloaded via http://images.cocodataset.org/zips/val2017.zip & "
            "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise
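
# Example invocation (a sketch: -a/--artifact, -d/--dataset and -s/--device are
# defined or referenced above; the remaining flag spellings are assumed to come
# from setup_common_args_and_variables, and the SoC name, paths and device serial
# below are placeholders):
#   python retinanet.py -a ./retinanet -d /path/to/COCO \
#       -m SM8550 -b /path/to/executorch/build-android -s <device_serial>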