xref: /aosp_15_r20/external/executorch/examples/qualcomm/oss_scripts/esrgan.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12import piq
13import torch
14
15from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16from executorch.examples.qualcomm.scripts.edsr import get_dataset
17from executorch.examples.qualcomm.utils import (
18    build_executorch_binary,
19    make_output_dir,
20    parse_skip_delegation_node,
21    setup_common_args_and_variables,
22    SimpleADB,
23)
24
25from torchvision.transforms.functional import to_pil_image
26
27
def get_instance(repo: str):
    """Build the Real-ESRGAN x2 network from a local checkout of its repository.

    Args:
        repo: Path to a clone of https://github.com/ai-forever/Real-ESRGAN.
            It is prepended to ``sys.path`` so ``RealESRGAN`` can be imported.

    Returns:
        ``RealESRGAN(...).model`` (the wrapped network) switched to eval mode.
        Weights come from the relative path ``weights/RealESRGAN_x2.pth``.
        ``download=True`` is passed, presumably so the wrapper fetches them
        when missing; confirm against the RealESRGAN implementation.
    """
    import sys

    # Put the checkout first so its RealESRGAN package wins over any other
    # installed copy.
    sys.path.insert(0, repo)

    from RealESRGAN import RealESRGAN

    # Required by layout transform. Only ever raise the limit: lowering a
    # limit the caller already increased could break unrelated deep
    # recursion later in the process.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
    model = RealESRGAN(torch.device("cpu"), scale=2)
    model.load_weights("weights/RealESRGAN_x2.pth", download=True)
    return model.model.eval()
40
41
42def main(args):
43    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
44
45    # ensure the working directory exist.
46    os.makedirs(args.artifact, exist_ok=True)
47
48    if not args.compile_only and args.device is None:
49        raise RuntimeError(
50            "device serial is required if not compile only. "
51            "Please specify a device serial by -s/--device argument."
52        )
53
54    dataset = get_dataset(
55        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
56    )
57
58    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
59    pte_filename = "esrgan_qnn"
60    instance = get_instance(args.oss_repo)
61
62    build_executorch_binary(
63        instance,
64        (inputs[0],),
65        args.model,
66        f"{args.artifact}/{pte_filename}",
67        [(input,) for input in inputs],
68        skip_node_id_set=skip_node_id_set,
69        skip_node_op_set=skip_node_op_set,
70        quant_dtype=QuantDtype.use_8a8w,
71        shared_buffer=args.shared_buffer,
72    )
73
74    if args.compile_only:
75        return
76
77    adb = SimpleADB(
78        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
79        build_path=f"{args.build_folder}",
80        pte_path=f"{args.artifact}/{pte_filename}.pte",
81        workspace=f"/data/local/tmp/executorch/{pte_filename}",
82        device_id=args.device,
83        host_id=args.host,
84        soc_model=args.model,
85    )
86    adb.push(inputs=inputs, input_list=input_list)
87    adb.execute()
88
89    # collect output data
90    output_data_folder = f"{args.artifact}/outputs"
91    output_pic_folder = f"{args.artifact}/output_pics"
92    make_output_dir(output_data_folder)
93    make_output_dir(output_pic_folder)
94
95    output_raws = []
96
97    def post_process():
98        cnt = 0
99        output_shape = tuple(targets[0].size())
100        for f in sorted(
101            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
102        ):
103            filename = os.path.join(output_data_folder, f)
104            output = np.fromfile(filename, dtype=np.float32)
105            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
106            output_raws.append(output)
107            to_pil_image(output.squeeze(0)).save(
108                os.path.join(output_pic_folder, str(cnt) + ".png")
109            )
110            cnt += 1
111
112    adb.pull(output_path=args.artifact, callback=post_process)
113
114    psnr_list = []
115    ssim_list = []
116    for i, hr in enumerate(targets):
117        psnr_list.append(piq.psnr(hr, output_raws[i]))
118        ssim_list.append(piq.ssim(hr, output_raws[i]))
119
120    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
121    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
122    if args.ip and args.port != -1:
123        with Client((args.ip, args.port)) as conn:
124            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
125    else:
126        print(f"Average of PSNR is: {avg_PSNR}")
127        print(f"Average of SSIM is: {avg_SSIM}")
128
129
if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./esrgan",
        default="./esrgan",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--oss_repo",
        help="Path to cloned https://github.com/ai-forever/Real-ESRGAN",
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener is waiting on (ip, port): report the failure to it
            # as JSON instead of crashing.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception unchanged so its type and
            # traceback are preserved (wrapping it in Exception lost both).
            raise