# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import numpy as np
import piq
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.scripts.edsr import get_dataset
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)

from torchvision.transforms.functional import to_pil_image


def get_instance(repo: str):
    """Build the Real-ESRGAN x2 model from a local checkout of the OSS repo.

    Args:
        repo: Path to a clone of https://github.com/ai-forever/Real-ESRGAN.
            It is prepended to ``sys.path`` so that ``RealESRGAN`` can be
            imported from it.

    Returns:
        The wrapped ``model.model`` module, switched to eval mode.
    """
    import sys

    sys.path.insert(0, repo)

    from RealESRGAN import RealESRGAN

    # required by layout transform. Only ever raise the limit so a larger
    # value configured elsewhere in the process is not silently lowered.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
    model = RealESRGAN(torch.device("cpu"), scale=2)
    # NOTE(review): the weights path is relative to the current working
    # directory; download=True presumably fetches it if missing — confirm.
    model.load_weights("weights/RealESRGAN_x2.pth", download=True)
    return model.model.eval()


def main(args):
    """Compile Real-ESRGAN for QNN, optionally run it on device, and score it.

    The model is quantized (8-bit activations / 8-bit weights) and exported to
    ``<artifact>/esrgan_qnn.pte``. Unless ``--compile_only`` is given, the
    program is pushed to the device over adb and executed. Outputs are saved
    as raw tensors and PNG images, and average PSNR/SSIM against the
    high-resolution references is reported. Results go either to the IPC
    client (when ``--ip``/``--port`` are set) or to stdout.

    Args:
        args: Parsed command-line namespace built in the ``__main__`` section.

    Raises:
        RuntimeError: If no device serial is given while running on device, or
            if the device returned fewer outputs than there are references.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "esrgan_qnn"
    instance = get_instance(args.oss_repo)

    build_executorch_binary(
        instance,
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        # Every low-resolution input is used as a calibration sample.
        [(sample,) for sample in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []

    def post_process():
        # Decode every raw float32 output, reshape it like the reference
        # images, and also dump it as a PNG for visual inspection.
        output_shape = tuple(targets[0].size())
        # Assumes file names carry the sample index as the second
        # "_"-separated field; sort numerically so outputs line up with
        # `targets` (a lexical sort would put "10" before "2").
        raw_files = sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        )
        for idx, f in enumerate(raw_files):
            filename = os.path.join(output_data_folder, f)
            output = np.fromfile(filename, dtype=np.float32)
            # Clamp to [0, 1] so both the metrics and the PNG conversion see
            # a valid image range.
            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
            output_raws.append(output)
            to_pil_image(output.squeeze(0)).save(
                os.path.join(output_pic_folder, f"{idx}.png")
            )

    adb.pull(output_path=args.artifact, callback=post_process)

    # Fail with a clear message instead of an IndexError when the device did
    # not produce an output for every reference image.
    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs from device, "
            f"but only {len(output_raws)} were collected."
        )

    psnr_list = []
    ssim_list = []
    for hr, sr in zip(targets, output_raws):
        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./esrgan",
        default="./esrgan",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--oss_repo",
        help="Path to cloned https://github.com/ai-forever/Real-ESRGAN",
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Forward the failure to the IPC listener instead of crashing.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception unchanged so its type and
            # traceback are preserved (wrapping it in Exception lost both).
            raise