xref: /aosp_15_r20/external/pytorch/test/package/package_a/test_nn_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import torch
4
5
class TestNnModule(torch.nn.Module):
    """DCGAN-style generator stack used as a ``torch.package`` test fixture.

    Four transposed-convolution stages, each followed by ``BatchNorm2d`` and
    an in-place ``ReLU``, map ``nz`` input channels through ``ngf * 8``,
    ``ngf * 4``, ``ngf * 2`` and ``ngf`` feature channels. A final transposed
    convolution followed by ``Tanh`` produces ``nc`` output channels.

    Spatial sizes follow from the layer hyper-parameters (kernel 4): the first
    stage (stride 1, padding 0) maps ``H`` to ``H + 3``, and every later
    stage (stride 2, padding 1) doubles it. An input of shape
    ``(N, nz, 1, 1)`` therefore yields an output of shape ``(N, nc, 64, 64)``.

    Args:
        nz: number of input channels.
        ngf: base feature-map width; the hidden stages use multiples of it.
        nc: number of output channels.
    """

    def __init__(self, nz=6, ngf=9, nc=3):
        super().__init__()
        # Layer order and indices inside ``main`` determine parameter names
        # (``main.0.weight``, ``main.1.running_mean``, ...), so keep them stable.
        self.main = torch.nn.Sequential(
            # input is Z, going into a convolution
            torch.nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            torch.nn.BatchNorm2d(ngf * 8),
            torch.nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4  (for a 1x1 input)
            torch.nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(ngf * 4),
            torch.nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            torch.nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(ngf * 2),
            torch.nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            torch.nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(ngf),
            torch.nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            torch.nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run ``input`` through the generator stack and return the result.

        ``input`` is expected to have ``nz`` channels (dim 1), as required by
        the first ``ConvTranspose2d``.
        """
        return self.main(input)
34