# Owner(s): ["oncall: package/deploy"]

import torch


class TestNnModule(torch.nn.Module):
    """Generator-style stack of transposed convolutions used as a test module.

    ``self.main`` is a ``torch.nn.Sequential`` of 14 layers:

    * four ``ConvTranspose2d -> BatchNorm2d -> ReLU(inplace=True)`` stages
      whose output channel counts are ``ngf * 8``, ``ngf * 4``, ``ngf * 2``
      and ``ngf``;
    * a final ``ConvTranspose2d -> Tanh`` stage producing ``nc`` channels.

    Every convolution uses a 4x4 kernel and no bias. For a 1x1 spatial input,
    the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        nz: number of input channels.
        ngf: base channel width of the intermediate feature maps.
        nc: number of output channels.
    """

    def __init__(self, nz=6, ngf=9, nc=3):
        super().__init__()
        layers = []
        in_channels = nz
        for multiplier in (8, 4, 2, 1):
            out_channels = ngf * multiplier
            # The first stage uses stride 1 / padding 0, so 1x1 -> 4x4. Every
            # later stage uses stride 2 / padding 1, which doubles the spatial
            # size.
            stride, padding = (1, 0) if not layers else (2, 1)
            layers.extend(
                [
                    torch.nn.ConvTranspose2d(
                        in_channels, out_channels, 4, stride, padding, bias=False
                    ),
                    torch.nn.BatchNorm2d(out_channels),
                    torch.nn.ReLU(inplace=True),
                ]
            )
            in_channels = out_channels
        # Output stage: ngf channels at 32x32 -> nc channels at 64x64, in [-1, 1].
        layers.append(torch.nn.ConvTranspose2d(in_channels, nc, 4, 2, 1, bias=False))
        layers.append(torch.nn.Tanh())
        # Layers are created in the same order as a literal Sequential listing.
        # That keeps the index-based parameter names ("main.0.weight", ...) and
        # the order of random weight initialisation identical.
        self.main = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        # ``input`` shadows the builtin, but the name is kept so callers
        # passing it by keyword keep working.
        return self.main(input)