1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport json 8*523fa7a6SAndroid Build Coastguard Workerimport os 9*523fa7a6SAndroid Build Coastguard Workerimport re 10*523fa7a6SAndroid Build Coastguard Workerfrom multiprocessing.connection import Client 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Workerimport numpy as np 13*523fa7a6SAndroid Build Coastguard Workerimport piq 14*523fa7a6SAndroid Build Coastguard Workerimport torch 15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.quantizer.quantizer import QuantDtype 16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.edsr import EdsrModel 17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.qualcomm.utils import ( 18*523fa7a6SAndroid Build Coastguard Worker build_executorch_binary, 19*523fa7a6SAndroid Build Coastguard Worker make_output_dir, 20*523fa7a6SAndroid Build Coastguard Worker parse_skip_delegation_node, 21*523fa7a6SAndroid Build Coastguard Worker setup_common_args_and_variables, 22*523fa7a6SAndroid Build Coastguard Worker SimpleADB, 23*523fa7a6SAndroid Build Coastguard Worker) 24*523fa7a6SAndroid Build Coastguard Worker 25*523fa7a6SAndroid Build Coastguard Workerfrom PIL import Image 26*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils.data import Dataset 27*523fa7a6SAndroid Build Coastguard Workerfrom torchsr.datasets import B100 28*523fa7a6SAndroid Build Coastguard Workerfrom torchvision.transforms.functional import to_pil_image, to_tensor 
class SrDataset(Dataset):
    """Paired high-/low-resolution image dataset for x2 super-resolution.

    Every file in ``hr_dir`` is resized to twice ``input_size`` and every file
    in ``lr_dir`` to ``input_size`` (224x224). Each is converted with
    ``to_tensor(...).unsqueeze(0)``, i.e. a tensor with a leading batch
    dimension of 1. HR and LR images are paired by sorted filename order.

    Raises:
        AssertionError: if the two directories hold a different number of
            files.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        self.input_size = np.asanyarray([224, 224])
        # HR images are the x2 reference targets, so they are resized to
        # twice the LR input size.
        self.hr = [
            self._resize_img(os.path.join(hr_dir, file), 2)
            for file in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, file), 1)
            for file in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(hr, lr)`` tensor pair at ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self):
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Open ``file``, resize it to ``input_size * scale`` and return a tensor."""
        with Image.open(file) as img:
            return to_tensor(img.resize(tuple(self.input_size * scale))).unsqueeze(0)

    def get_input_list(self) -> str:
        """Return the device input list: one ``input_{i}_0.raw`` line per LR sample."""
        return "".join(f"input_{i}_0.raw\n" for i in range(len(self.lr)))


def get_b100(
    dataset_dir: str,
):
    """Return the B100 x2 benchmark as an ``SrDataset``.

    The dataset is downloaded through torchSR into
    ``{dataset_dir}/sr_bm_dataset`` when either expected directory is missing.
    """
    hr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR"
    lr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2"

    if not os.path.exists(hr_dir) or not os.path.exists(lr_dir):
        B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True)

    return SrDataset(hr_dir, lr_dir)


def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str):
    """Select the custom HR/LR dataset or the default B100 dataset.

    Args:
        hr_dir: directory of high-resolution reference images (may be empty).
        lr_dir: directory of low-resolution input images (may be empty).
        default_dataset: truthy to use B100 instead of custom directories.
        dataset_dir: root directory where B100 is (or will be) stored.

    Raises:
        RuntimeError: if neither a complete custom dataset nor the default
            dataset is requested, or if both are requested.
    """
    if not (lr_dir and hr_dir) and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if (lr_dir and hr_dir) and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)


def main(args):
    """Compile EDSR to an 8a8w QNN ``.pte``, run it on device and score it.

    If ``args.compile_only`` is set, this returns right after the binary is
    built. Otherwise it does the following:
      1. Push the inputs to the device and execute the model there.
      2. Pull the outputs and save each one as a PNG.
      3. Compute the average PSNR/SSIM against the HR references.
      4. Send the averages as JSON to ``args.ip:args.port``, or print them.

    Raises:
        RuntimeError: if no device serial is given while not compile-only,
            or if the device returned fewer outputs than there are targets.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        [(lr_img,) for lr_img in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []
    # Files named output_<i>_<k>.raw with k >= 1 are not the image output
    # and are deleted; only output_<i>_0.raw is kept and decoded.
    extra_output_re = re.compile(r"^output_[0-9]+_[1-9][0-9]*\.raw$")

    def post_process():
        cnt = 0
        output_shape = tuple(targets[0].size())
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            filename = os.path.join(output_data_folder, f)
            if extra_output_re.match(f):
                os.remove(filename)
            else:
                output = np.fromfile(filename, dtype=np.float32)
                # Reshape to the HR target shape and clamp to the valid
                # [0, 1] pixel range expected by to_pil_image and piq.
                output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
                output_raws.append(output)
                to_pil_image(output.squeeze(0)).save(
                    os.path.join(output_pic_folder, str(cnt) + ".png")
                )
                cnt += 1

    adb.pull(output_path=args.artifact, callback=post_process)

    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs from device, "
            f"got {len(output_raws)}."
        )

    psnr_list = [piq.psnr(hr, sr) for hr, sr in zip(targets, output_raws)]
    ssim_list = [piq.ssim(hr, sr) for hr, sr in zip(targets, output_raws)]

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the listening host instead of crashing.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception, preserving type and traceback.
            raise