xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/srresnet.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import math
2
3from torch import nn
4from torch.nn import init
5
6
def _initialize_orthogonal(conv):
    """Orthogonally initialise a layer's weight and zero its bias, in place.

    The weight is filled with ``init.orthogonal_`` using a gain of
    ``sqrt(2)``. The bias is reset to zeros when the layer has one; layers
    created with ``bias=False`` are left without a bias.

    Args:
        conv: Module exposing ``weight`` and ``bias`` attributes. Every call
            site in this file passes an ``nn.Conv2d``.
    """
    prelu_gain = math.sqrt(2)
    # Use the in-place ``orthogonal_`` API: the underscore-less
    # ``init.orthogonal`` is a deprecated alias that warns on every call.
    init.orthogonal_(conv.weight, gain=prelu_gain)
    if conv.bias is not None:
        # ``init.zeros_`` runs under ``no_grad`` and avoids touching ``.data``.
        init.zeros_(conv.bias)
12
13
class ResidualBlock(nn.Module):
    """Residual unit: conv -> batch norm -> PReLU -> conv -> batch norm, plus skip.

    Both convolutions are 3x3 with padding 1 and no bias, and keep the channel
    count at ``n_filters``, so the branch output can be added to the input.

    Args:
        n_filters: Number of input and output channels of the block.
    """

    def __init__(self, n_filters):
        super().__init__()
        # Settings shared by the two convolutions of the residual branch.
        conv_kwargs = {"kernel_size": 3, "padding": 1, "bias": False}

        self.conv1 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU(n_filters)
        self.conv2 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(n_filters)

        # Orthogonal initialisation, applied once both layers exist so the
        # order of random draws is conv1, conv2, then the two re-inits.
        for conv in (self.conv1, self.conv2):
            _initialize_orthogonal(conv)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch applied to ``x``."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.prelu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return x + branch
35
36
class UpscaleBlock(nn.Module):
    """Upscaling stage: 3x3 conv to 4x channels, ``PixelShuffle(2)``, then PReLU.

    The convolution expands ``n_filters`` channels to ``4 * n_filters``; the
    pixel shuffle with factor 2 brings the channel count back to ``n_filters``,
    which is the channel count the PReLU is built for.

    Args:
        n_filters: Number of input and output channels of the block.
    """

    def __init__(self, n_filters):
        super().__init__()
        expanded_filters = 4 * n_filters
        self.upscaling_conv = nn.Conv2d(
            n_filters, expanded_filters, kernel_size=3, padding=1
        )
        self.upscaling_shuffler = nn.PixelShuffle(2)
        self.upscaling = nn.PReLU(n_filters)

        # Orthogonal initialisation of the only convolution in the block.
        _initialize_orthogonal(self.upscaling_conv)

    def forward(self, x):
        """Apply the convolution, the pixel shuffle and the PReLU, in that order."""
        expanded = self.upscaling_conv(x)
        shuffled = self.upscaling_shuffler(expanded)
        return self.upscaling(shuffled)
49
50
class SRResNet(nn.Module):
    """Super-resolution network built from residual and upscaling blocks.

    Layout, in forward order: a 9x9 input convolution with PReLU,
    ``n_blocks`` residual blocks, a 3x3 convolution with batch norm whose
    output is added to the input-stage activation, ``rescale_levels``
    upscaling blocks, and a 9x9 output convolution back to 3 channels.

    Residual and upscaling blocks are registered as submodules named
    ``residual_block1..n_blocks`` and ``upscale_block1..rescale_levels``,
    each wrapped in an ``nn.Sequential``.

    Args:
        rescale_factor: Overall upscaling factor. The number of upscaling
            blocks is ``int(log2(rescale_factor))``.
        n_filters: Channel width used throughout the network body.
        n_blocks: Number of residual blocks; ``0`` yields a network with
            no residual blocks.
    """

    def __init__(self, rescale_factor, n_filters, n_blocks):
        super().__init__()
        # math.log2 is more accurate than math.log(x, 2), whose rounded
        # quotient is then truncated by int().
        self.rescale_levels = int(math.log2(rescale_factor))
        self.n_filters = n_filters
        self.n_blocks = n_blocks

        self.conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
        self.prelu1 = nn.PReLU(n_filters)

        for residual_block_num in range(1, n_blocks + 1):
            self.add_module(
                f"residual_block{residual_block_num}",
                nn.Sequential(ResidualBlock(self.n_filters)),
            )

        self.skip_conv = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.skip_bn = nn.BatchNorm2d(n_filters)

        for upscale_block_num in range(1, self.rescale_levels + 1):
            self.add_module(
                f"upscale_block{upscale_block_num}",
                nn.Sequential(UpscaleBlock(self.n_filters)),
            )

        self.output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)

        # Orthogonal initialisation
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.skip_conv)
        _initialize_orthogonal(self.output_conv)

    def forward(self, x):
        """Run the network on ``x`` and return the output convolution's result."""
        x_init = self.prelu1(self.conv1(x))

        # Start the residual chain from the input-stage activation so that a
        # model with ``n_blocks == 0`` skips the chain instead of failing on
        # a missing ``residual_block1`` attribute.
        x = x_init
        for residual_block_num in range(1, self.n_blocks + 1):
            x = getattr(self, f"residual_block{residual_block_num}")(x)

        # Long skip connection around the whole residual chain.
        x = self.skip_bn(self.skip_conv(x)) + x_init

        for upscale_block_num in range(1, self.rescale_levels + 1):
            x = getattr(self, f"upscale_block{upscale_block_num}")(x)
        return self.output_conv(x)
95