xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/super_resolution.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch.nn as nn
2import torch.nn.init as init
3
4
class SuperResolutionNet(nn.Module):
    """Small sub-pixel convolution network that upsamples single-channel images.

    Three ReLU-activated convolutions extract features. A final convolution
    produces ``upscale_factor**2`` channels, and ``nn.PixelShuffle`` rearranges
    them into one channel at ``upscale_factor`` times the input height and
    width. Every convolution uses stride 1 with padding matched to its kernel,
    so the spatial size is unchanged until the shuffle.

    Args:
        upscale_factor: Integer factor applied to both spatial dimensions.
    """

    def __init__(self, upscale_factor):
        super().__init__()

        # Submodules are registered in this exact order. The order determines
        # state_dict / named_children ordering and the sequence in which the
        # Conv2d constructors draw from the global RNG.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(
            1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)
        )
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.conv3 = nn.Conv2d(
            64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        # One output channel per sub-pixel position; PixelShuffle folds the
        # upscale_factor**2 channels back into a single upscaled channel.
        self.conv4 = nn.Conv2d(
            32,
            upscale_factor**2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Upscale ``x``.

        ``conv1`` accepts a single input channel, so ``x`` is expected to look
        like ``(N, 1, H, W)``. The result is
        ``(N, 1, H * upscale_factor, W * upscale_factor)``.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Re-initialize every conv weight orthogonally; biases are left as-is.

        Layers that feed a ReLU are scaled by the ReLU gain. ``conv4`` feeds
        the pixel shuffle directly and keeps ``orthogonal_``'s default gain.
        """
        relu_gain = init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
30