import torch
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    """Four-layer convolutional super-resolution network.

    Three ReLU-activated convolutions extract features at the input
    resolution (their padding keeps H and W unchanged). A final convolution
    emits ``num_channels * upscale_factor**2`` channels, which
    ``nn.PixelShuffle`` rearranges into an image ``upscale_factor`` times
    larger in each spatial dimension (sub-pixel convolution).

    Args:
        upscale_factor: Integer spatial upscaling factor ``r`` (must be >= 1).
        num_channels: Number of image channels in both input and output.
            Defaults to 1, the original single-channel configuration.

    Shape:
        - Input: ``(N, num_channels, H, W)``
        - Output: ``(N, num_channels, H * r, W * r)``

    Raises:
        ValueError: If ``upscale_factor`` or ``num_channels`` is less than 1.
    """

    def __init__(self, upscale_factor: int, *, num_channels: int = 1) -> None:
        if upscale_factor < 1:
            raise ValueError(f"upscale_factor must be >= 1, got {upscale_factor}")
        if num_channels < 1:
            raise ValueError(f"num_channels must be >= 1, got {num_channels}")
        super().__init__()

        self.relu = nn.ReLU()
        # Positional args are (in, out, kernel_size, stride, padding); each
        # kernel/padding pair is chosen so the spatial size is preserved.
        self.conv1 = nn.Conv2d(num_channels, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        # One group of r**2 sub-pixel channels per output image channel, as
        # required by PixelShuffle to produce `num_channels` output channels.
        self.conv4 = nn.Conv2d(
            32, num_channels * upscale_factor**2, (3, 3), (1, 1), (1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` by ``upscale_factor`` along both spatial axes."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # No activation on the last conv: its outputs are the sub-pixel values.
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self) -> None:
        """Orthogonally initialize conv weights; biases keep Conv2d defaults."""
        # Layers followed by ReLU use the ReLU gain; conv4 has no activation,
        # so it uses orthogonal_'s default gain of 1. Order matches the
        # original (conv1..conv4) so seeded initialization is unchanged.
        relu_gain = init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)