# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client
from pprint import PrettyPrinter

import numpy as np
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)

# Number of classes the SSD300 head is constructed with (see SSD300VGG16); it is
# also the last dimension of the raw class-score output read back from device.
_NUM_CLASSES = 21
# Number of prior boxes per image in the raw device outputs. The raw buffers are
# reshaped to this size, so it must match what the compiled model emits.
_NUM_PRIORS = 8732


def create_data_lists(voc07_path, data_size):
    """
    Create lists of test images and of the objects annotated in them, and save
    both lists as JSON files inside ``voc07_path``.

    Only the first ``data_size`` IDs listed in ``ImageSets/Main/test.txt`` are
    considered; an ID whose parsed annotation is empty is skipped, so fewer than
    ``data_size`` images may be written.

    Writes ``TEST_images.json`` (image file paths) and ``TEST_objects.json``
    (the parsed annotation of each image) into ``voc07_path``.

    :param voc07_path: path to the 'VOC2007' folder
    :param data_size: maximum number of test-set IDs to read
    """
    # `utils` comes from the --oss_repo checkout, which main() puts on sys.path.
    from utils import parse_annotation

    voc07_path = os.path.abspath(voc07_path)

    test_images = []
    test_objects = []
    n_objects = 0

    # Find IDs of images in the test data.
    with open(os.path.join(voc07_path, "ImageSets/Main/test.txt")) as f:
        image_ids = f.read().splitlines()

    for index, image_id in enumerate(image_ids):
        if index >= data_size:
            break
        # Parse the annotation's XML file.
        objects = parse_annotation(
            os.path.join(voc07_path, "Annotations", image_id + ".xml")
        )
        # NOTE(review): if parse_annotation returns a dict of per-field lists,
        # len(objects) counts fields rather than objects, so this skip and the
        # n_objects tally would not mean what they say -- verify in --oss_repo.
        if len(objects) == 0:
            continue
        test_objects.append(objects)
        n_objects += len(objects)
        test_images.append(os.path.join(voc07_path, "JPEGImages", image_id + ".jpg"))

    assert len(test_objects) == len(test_images)

    # TEST_images.json stores the file name of the images, and TEST_objects.json
    # stores info such as boxes, labels, and difficulties.
    with open(os.path.join(voc07_path, "TEST_images.json"), "w") as j:
        json.dump(test_images, j)
    with open(os.path.join(voc07_path, "TEST_objects.json"), "w") as j:
        json.dump(test_objects, j)

    print(
        "\nThere are %d test images containing a total of %d objects. Files have been saved to %s."
        % (len(test_images), n_objects, voc07_path)
    )


def get_dataset(data_size, dataset_dir, download):
    """
    Prepare up to ``data_size`` VOC2007 test samples and their ground truth.

    :param data_size: maximum number of samples to collect
    :param dataset_dir: directory under which ``voc_image/VOCdevkit/VOC2007``
        lives (or is downloaded to)
    :param download: if True, download the VOC2007 test split via torchvision
        before building the data lists
    :return: ``(inputs, input_list, true_boxes, true_labels, true_difficulties)``
        where ``inputs`` is a list of 1-tuples of image tensors, ``input_list``
        names one raw input file per sample, and the ``true_*`` lists hold the
        matching ground truth in the same order as ``inputs``
    """
    # `datasets` here is the --oss_repo module; the torchvision import below
    # rebinds only the local name `datasets`, not PascalVOCDataset.
    from datasets import PascalVOCDataset
    from torchvision import datasets

    if download:
        datasets.VOCSegmentation(
            root=os.path.join(dataset_dir, "voc_image"),
            year="2007",
            image_set="test",
            download=True,
        )
    voc07_path = os.path.join(dataset_dir, "voc_image", "VOCdevkit", "VOC2007")
    create_data_lists(voc07_path, data_size)

    # voc07_path is where the data and ground-truth JSON files are stored.
    test_dataset = PascalVOCDataset(voc07_path, split="test", keep_difficult=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, shuffle=True, collate_fn=test_dataset.collate_fn
    )

    inputs, input_list = [], ""
    true_boxes = []
    true_labels = []
    true_difficulties = []
    for index, (images, boxes, labels, difficulties) in enumerate(test_loader):
        if index >= data_size:
            break
        inputs.append((images,))
        input_list += f"input_{index}_0.raw\n"
        true_boxes.extend(boxes)
        true_labels.extend(labels)
        true_difficulties.extend(difficulties)

    return inputs, input_list, true_boxes, true_labels, true_difficulties


def SSD300VGG16(pretrained_weight_model):
    """
    Build the SSD300 model from --oss_repo and load a pretrained checkpoint.

    :param pretrained_weight_model: path to a checkpoint whose ``"model"`` entry
        is a module carrying the trained weights
    :return: the model in eval mode
    """
    from model import SSD300

    model = SSD300(n_classes=_NUM_CLASSES)
    # TODO: If possible, it's better to set weights_only to True
    # https://pytorch.org/docs/stable/generated/torch.load.html
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    checkpoint = torch.load(
        pretrained_weight_model, map_location="cpu", weights_only=False
    )
    model.load_state_dict(checkpoint["model"].state_dict())

    return model.eval()


def main(args):
    """
    Compile SSD300-VGG16 for QNN, optionally run it on device, and report mAP.

    :param args: parsed command-line arguments (see the ``__main__`` block)
    :raises RuntimeError: if a device run is requested without a device serial
    """
    # The --oss_repo checkout provides the `utils`, `datasets` and `model` modules.
    sys.path.insert(0, args.oss_repo)

    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    inputs, input_list, true_boxes, true_labels, true_difficulties = get_dataset(
        data_size=data_num, dataset_dir=args.artifact, download=args.download
    )

    pte_filename = "ssd300_vgg16_qnn"
    model = SSD300VGG16(args.pretrained_weight)

    sample_input = (torch.randn((1, 3, 300, 300)),)
    build_executorch_binary(
        model,
        sample_input,
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # Collect output data.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    def post_process():
        """Decode the pulled raw outputs, run detection and report mAP."""
        from utils import calculate_mAP

        np.set_printoptions(threshold=np.inf)

        det_boxes, det_labels, det_scores = [], [], []

        # One output pair exists per pushed input. The dataset may yield fewer
        # than data_num samples, so iterate over what was actually pushed.
        for file_index in range(len(inputs)):
            # output_xxx_0.raw is output of boxes, and output_xxx_1.raw is
            # output of classes.
            boxes_filename = os.path.join(
                output_data_folder, f"output_{file_index}_0.raw"
            )
            category_filename = os.path.join(
                output_data_folder, f"output_{file_index}_1.raw"
            )

            predicted_locs = torch.tensor(
                np.fromfile(boxes_filename, dtype=np.float32).reshape(
                    [1, _NUM_PRIORS, 4]
                )
            )
            predicted_scores = torch.tensor(
                np.fromfile(category_filename, dtype=np.float32).reshape(
                    [1, _NUM_PRIORS, _NUM_CLASSES]
                )
            )

            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
                predicted_locs,
                predicted_scores,
                min_score=0.01,
                max_overlap=0.45,
                top_k=200,
            )

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)

        pp = PrettyPrinter()
        # Calculate mAP.
        APs, mAP = calculate_mAP(
            det_boxes,
            det_labels,
            det_scores,
            true_boxes,
            true_labels,
            true_difficulties,
        )
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"mAP": float(mAP)}))
        else:
            print("\nMean Average Precision (mAP): %.3f" % mAP)
            pp.pprint(APs)

    adb.pull(output_path=args.artifact, callback=post_process)


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./ssd300_vgg16",
        default="./ssd300_vgg16",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--download",
        help="If specified, download VOCSegmentation dataset by torchvision API",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--oss_repo",
        help=(
            "Repository that contains model backbone and score calculation "
            "(e.g., --oss_repo ./a-PyTorch-Tutorial-to-Object-Detection). "
            "Please clone the repository from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of model pretrained weight "
            "(e.g., -p ./checkpoint_ssd300.pth.tar). "
            "Pretrained model can be found in the link https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection, under the Training Section"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type and traceback.
            raise