# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    ProfilingMode,
    set_default_dtype,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.reshape(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class TestModels(JitTestCase):
    @staticmethod
    def _test_dcgan_models(self, device, check_export_import=True):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super().__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh()
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super().__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        self.checkTrace(
            DCGANGenerator(nz, ngf, nc).to(device),
            (torch.rand(bs, nz, 1, 1, device=device),),
            export_import=check_export_import,
        )
        example_input = DCGANGenerator(nz, ngf, nc).to(device)(
            torch.rand(bs, nz, 1, 1, device=device)
        )
        self.checkTrace(
            DCGANDiscriminator(nc, ndf).to(device),
            (example_input,),
            export_import=check_export_import,
        )

    def test_dcgan_models(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            self._test_dcgan_models(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_dcgan_models_cuda(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            # XXX: export_import on CUDA modules doesn't work (#11480)
            self._test_dcgan_models(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_neural_style(self, device, check_export_import=True):
        class TransformerNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(
                    128, 64, kernel_size=3, stride=1, upsample=2
                )
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(
                    64, 32, kernel_size=3, stride=1, upsample=2
                )
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super().__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(
                self, in_channels, out_channels, kernel_size, stride, upsample=None
            ):
                super().__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(
                        mode="nearest", scale_factor=upsample
                    )
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        self.checkTrace(
            TransformerNet(),
            (torch.rand(5, 3, 16, 16),),
            export_import=check_export_import,
        )

    @slowTest
    def test_neural_style(self):
        self._test_neural_style(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_neural_style_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_neural_style(self, device="cuda", check_export_import=False)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
    )
    @staticmethod
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                MnistNet().to(device).eval(),
                (torch.rand(5, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_mnist(self):
        self._test_mnist(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_mnist(self, device="cuda", check_export_import=False)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_training_leaks_no_memory_cuda(self):
        net = MnistNet().cuda()
        # MnistNet uses dropout, don't check its trace
        traced_net = torch.jit.trace(
            net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
        )

        def train(iters):
            for _ in range(iters):
                # Get some fake data
                inp = torch.randn(5, 1, 28, 28, device="cuda")
                out = traced_net(inp)

                # Here's some fake loss
                out.sum().backward()

                # Zero out grads
                traced_net.zero_grad()

        # Set it up so the params have .grad fields so they are not reported as leaks
        train(1)

        with self.assertLeaksNoCudaTensors():
            train(5)

    @staticmethod
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                Policy().to(device),
                (torch.rand(1, 4, device=device),),
                export_import=test_export_import,
            )

    def test_reinforcement_learning(self):
        self._test_reinforcement_learning(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_reinforcement_learning_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_reinforcement_learning(self, device="cuda", test_export_import=False)

    @staticmethod
    def _test_snli(self, device, check_export_import=True):
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super().forward(input)
                size = input.size()[:2]
                out = super().forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass

        class Encoder(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(
                    input_size=input_size,
                    hidden_size=config.d_hidden,
                    num_layers=config.n_layers,
                    dropout=dropout,
                    bidirectional=config.birnn,
                )

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
                return (
                    ht[-1]
                    if not self.config.birnn
                    else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
                )

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out),
                )

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

        premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

        self.checkTrace(
            SNLIClassifier(Config()).to(device),
            (premise, hypothesis),
            inputs_require_grads=False,
            export_import=check_export_import,
        )

    @slowTest
    def test_snli(self):
        self._test_snli(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_snli_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_snli(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_super_resolution(self, device, check_export_import=True):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super().__init__()

                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

        net = Net(upscale_factor=4).to(device)
        self.checkTrace(
            net,
            (torch.rand(5, 1, 32, 32, device=device),),
            export_import=check_export_import,
        )

    @slowTest
    def test_super_resolution(self):
        self._test_super_resolution(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_super_resolution_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_super_resolution(self, device="cuda", check_export_import=False)

    @suppress_warnings
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))

                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        # disabled due to a jitter issue that will be fixed by using load/store in the compiler
        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)

    @staticmethod
    def _test_vae(self, device, check_export_import=True):
        class VAE(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(
                VAE().to(device).eval(),
                (torch.rand(128, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_vae(self):
        self._test_vae(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_vae_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_vae(self, device="cuda", check_export_import=False)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_trace_resnet18(self):
        x = torch.ones(1, 3, 224, 224)
        m_orig = torch.jit.trace(
            torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
        )
        m_import = self.getExportImportCopy(m_orig)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = m_orig(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()

        output_import = m_import(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_script_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
            __constants__ = ["downsample"]

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super().__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ["layer1", "layer2", "layer3", "layer4"]

            def __init__(self, block, layers, num_classes=1000):
                super().__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(
                            m.weight, mode="fan_out", nonlinearity="relu"
                        )
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )

                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

        resnet18_imported = self.getExportImportCopy(resnet18)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = resnet18(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()
        output_import = resnet18_imported(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        with torch.random.fork_rng(devices=[]):
            g, outputs, inputs = torch.jit._get_trace_graph(
                model, x, return_inputs=True
            )
        self.run_pass("cse", g)
        m = self.createFunctionFromGraph(g)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))