# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import re
from multiprocessing.connection import Client

import numpy as np
import piq
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.models.edsr import EdsrModel
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)

from PIL import Image
from torch.utils.data import Dataset
from torchsr.datasets import B100
from torchvision.transforms.functional import to_pil_image, to_tensor


class SrDataset(Dataset):
    """Paired high-resolution / low-resolution image dataset held in memory.

    Every file in ``hr_dir`` and ``lr_dir`` is loaded eagerly at construction
    time. Files are paired by their position in the sorted directory listing,
    so both directories must contain the same number of images.

    Attributes:
        input_size: Base (width, height) used for resizing; LR images are
            resized to this size and HR images to twice this size.
        hr: List of HR image tensors, each with a leading dimension of 1.
        lr: List of LR image tensors, each with a leading dimension of 1.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        """Load and resize all images found in ``hr_dir`` and ``lr_dir``.

        Args:
            hr_dir: Directory containing the high-resolution reference images.
            lr_dir: Directory containing the low-resolution input images.

        Raises:
            AssertionError: If the two directories hold a different number of
                files.
        """
        self.input_size = np.asanyarray([224, 224])
        # HR references are kept at 2x the LR input size.
        self.hr = [
            self._resize_img(os.path.join(hr_dir, file), 2)
            for file in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, file), 1)
            for file in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(hr, lr)`` tensor pair stored at position ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self):
        """Return the number of LR images in the dataset."""
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Open ``file``, resize it to ``input_size * scale`` and tensorize it.

        Args:
            file: Path of the image to load.
            scale: Integer multiplier applied to ``self.input_size``.

        Returns:
            The ``to_tensor`` result with an extra leading dimension added by
            ``unsqueeze(0)``.
        """
        with Image.open(file) as img:
            return to_tensor(img.resize(tuple(self.input_size * scale))).unsqueeze(0)

    def get_input_list(self):
        """Return one ``input_<i>_0.raw`` line per LR image.

        Each entry is terminated by a newline; an empty dataset yields an
        empty string.
        """
        return "".join(f"input_{i}_0.raw\n" for i in range(len(self.lr)))


def get_b100(
    dataset_dir: str,
):
    """Return the B100 benchmark as an :class:`SrDataset`.

    The dataset is requested through ``torchsr.datasets.B100`` with
    ``download=True`` when either expected directory is missing under
    ``dataset_dir``.

    Args:
        dataset_dir: Root directory under which ``sr_bm_dataset`` lives.

    Returns:
        An ``SrDataset`` built from the B100 HR and X2 bicubic LR folders.
    """
    hr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR"
    lr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2"

    if not os.path.exists(hr_dir) or not os.path.exists(lr_dir):
        B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True)

    return SrDataset(hr_dir, lr_dir)


def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str):
    """Select between a user-supplied dataset and the default B100 dataset.

    Exactly one source must be chosen: either both ``hr_dir`` and ``lr_dir``
    are given, or ``default_dataset`` is truthy.

    Args:
        hr_dir: Directory of high-resolution images (custom dataset).
        lr_dir: Directory of low-resolution images (custom dataset).
        default_dataset: Truthy to use the B100 dataset.
        dataset_dir: Root directory passed to :func:`get_b100`.

    Returns:
        The selected ``SrDataset``.

    Raises:
        RuntimeError: If neither source or both sources are specified.
    """
    has_custom = bool(lr_dir and hr_dir)

    if not has_custom and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if has_custom and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)


def main(args):
    """Build the EDSR binary, run it on device and report PSNR / SSIM.

    Steps performed:
      1. Load the dataset selected by the command-line arguments.
      2. Call ``build_executorch_binary`` with ``QuantDtype.use_8a8w``, using
         the LR images as the sample list.
      3. Unless ``args.compile_only`` is set, push inputs through
         ``SimpleADB``, execute, pull the outputs and post-process them.
      4. Compute the average PSNR and SSIM against the HR targets and either
         send them as JSON to ``(args.ip, args.port)`` or print them.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section for
            the options added on top of the common ones).

    Raises:
        RuntimeError: If a device serial is missing while not in compile-only
            mode, if the dataset is empty, or if the number of outputs pulled
            from the device differs from the number of targets.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exist.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    if not inputs:
        # Fail with a clear message instead of an IndexError on inputs[0].
        raise RuntimeError("The dataset is empty; no images were found.")

    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        [(lr_image,) for lr_image in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []
    # Files named output_<n>_<1-9>.raw are discarded; only the remaining
    # files are decoded. NOTE(review): presumably these are secondary model
    # outputs that are not needed here - confirm against the runner.
    extra_output_pattern = re.compile(r"^output_[0-9]+_[1-9]\.raw$")

    def post_process():
        """Decode pulled raw outputs into tensors and save them as PNGs."""
        output_shape = tuple(targets[0].size())
        # Sort numerically by the index embedded in the file name so that
        # outputs line up with the order of `targets`.
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            filename = os.path.join(output_data_folder, f)
            if extra_output_pattern.match(f):
                os.remove(filename)
                continue

            output = np.fromfile(filename, dtype=np.float32)
            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
            # The picture index equals the number of outputs kept so far.
            picture_name = str(len(output_raws)) + ".png"
            output_raws.append(output)
            to_pil_image(output.squeeze(0)).save(
                os.path.join(output_pic_folder, picture_name)
            )

    adb.pull(output_path=args.artifact, callback=post_process)

    if len(output_raws) != len(targets):
        # Fail explicitly rather than with an IndexError in the loop below.
        raise RuntimeError(
            f"Expected {len(targets)} outputs from device, "
            f"but got {len(output_raws)}."
        )

    psnr_list = []
    ssim_list = []
    for i, hr in enumerate(targets):
        psnr_list.append(piq.psnr(hr, output_raws[i]))
        ssim_list.append(piq.ssim(hr, output_raws[i]))

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the listening client instead of raising.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare re-raise keeps the original exception type and traceback.
            raise