"""Residual convolutional network that upscales 3-channel images.

The module defines three building blocks (``ResidualBlock``,
``UpscaleBlock`` and the top-level ``SRResNet``) plus the weight
initialisation helper they share.
"""

import math

from torch import Tensor, nn
from torch.nn import init

# Gain handed to the orthogonal initialiser for every convolution below
# (the original code names this value the "PReLU gain").
_ORTHOGONAL_GAIN = math.sqrt(2)


def _initialize_orthogonal(conv: nn.Conv2d) -> None:
    """Initialise *conv* in place: orthogonal weights, zeroed bias.

    Args:
        conv: Layer exposing ``weight`` and ``bias`` attributes; ``bias``
            may be ``None``, in which case only the weight is touched.
    """
    # Use the in-place ``orthogonal_``; the un-suffixed ``init.orthogonal``
    # is a deprecated alias that warns on every call.
    init.orthogonal_(conv.weight, gain=_ORTHOGONAL_GAIN)
    if conv.bias is not None:
        init.zeros_(conv.bias)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in an identity skip.

    Args:
        n_filters: Channel count of both the input and the output.
    """

    def __init__(self, n_filters: int) -> None:
        super().__init__()
        # Convolutions are bias-free because each one is followed directly
        # by a BatchNorm2d layer.
        self.conv1 = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU(n_filters)
        self.conv2 = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(n_filters)

        # Orthogonal initialisation (order kept so seeded runs reproduce).
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.conv2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the conv-bn-prelu-conv-bn transform of ``x``."""
        residual = self.prelu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return x + residual


class UpscaleBlock(nn.Module):
    """Upscale by 2: 3x3 conv to 4x channels, ``PixelShuffle(2)``, PReLU.

    Args:
        n_filters: Channel count of both the input and the output.
    """

    def __init__(self, n_filters: int) -> None:
        super().__init__()
        # PixelShuffle(2) folds the 4 * n_filters channels back down to
        # n_filters, which is why the PReLU has n_filters parameters.
        self.upscaling_conv = nn.Conv2d(
            n_filters, 4 * n_filters, kernel_size=3, padding=1
        )
        self.upscaling_shuffler = nn.PixelShuffle(2)
        self.upscaling = nn.PReLU(n_filters)
        _initialize_orthogonal(self.upscaling_conv)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution, the pixel shuffle and the activation."""
        return self.upscaling(self.upscaling_shuffler(self.upscaling_conv(x)))


class SRResNet(nn.Module):
    """Residual trunk followed by a chain of x2 upscaling blocks.

    Args:
        rescale_factor: Requested upscaling factor. The number of x2
            upscale blocks is ``int(log2(rescale_factor))``, so a factor
            that is not a power of two is rounded down to one.
        n_filters: Channel width used throughout the trunk.
        n_blocks: Number of residual blocks in the trunk; ``0`` is allowed.

    Raises:
        ValueError: If ``rescale_factor`` is not positive (from ``math.log2``).
    """

    def __init__(self, rescale_factor: float, n_filters: int, n_blocks: int) -> None:
        super().__init__()
        # log2 is exact for powers of two, unlike math.log(x, 2), which
        # divides two rounded logarithms before truncation.
        self.rescale_levels = int(math.log2(rescale_factor))
        self.n_filters = n_filters
        self.n_blocks = n_blocks

        self.conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
        self.prelu1 = nn.PReLU(n_filters)

        # Each block is registered under a numbered name and wrapped in
        # nn.Sequential; both details are part of the state_dict key layout
        # (e.g. "residual_block1.0.conv1.weight") and must stay as they are.
        for block_num in range(1, n_blocks + 1):
            self.add_module(
                f"residual_block{block_num}",
                nn.Sequential(ResidualBlock(self.n_filters)),
            )

        self.skip_conv = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.skip_bn = nn.BatchNorm2d(n_filters)

        for block_num in range(1, self.rescale_levels + 1):
            self.add_module(
                f"upscale_block{block_num}",
                nn.Sequential(UpscaleBlock(self.n_filters)),
            )

        self.output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)

        # Orthogonal initialisation (order kept so seeded runs reproduce).
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.skip_conv)
        _initialize_orthogonal(self.output_conv)

    def forward(self, x: Tensor) -> Tensor:
        """Run the network on ``x`` and return the upscaled result.

        ``x`` must have 3 channels (fixed by ``conv1``); the output also has
        3 channels and each spatial dimension is multiplied by
        ``2 ** rescale_levels``.
        """
        x_init = self.prelu1(self.conv1(x))

        # Start from the stem output so an empty trunk (n_blocks == 0)
        # works instead of failing on a missing "residual_block1".
        x = x_init
        for block_num in range(1, self.n_blocks + 1):
            x = getattr(self, f"residual_block{block_num}")(x)

        # Long skip connection around the whole residual trunk.
        x = self.skip_bn(self.skip_conv(x)) + x_init

        for block_num in range(1, self.rescale_levels + 1):
            x = getattr(self, f"upscale_block{block_num}")(x)
        return self.output_conv(x)