# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    ProfilingMode,
    set_default_dtype,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
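        # After two 5x5 convs and 2x2 max-pools, a 28x28 input is 20 channels x 4 x 4 = 320 features.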
        x = x.reshape(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class TestModels(JitTestCase):
    @staticmethod
    def _test_dcgan_models(self, device, check_export_import=True):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super().__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh()
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super().__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

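        # bs: batch size, nz: latent vector size, ngf/ndf: generator/discriminator feature map widths, nc: image channels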
        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        self.checkTrace(
            DCGANGenerator(nz, ngf, nc).to(device),
            (torch.rand(bs, nz, 1, 1, device=device),),
            export_import=check_export_import,
        )
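        # Run a fresh generator once to produce a (bs, nc, 64, 64) image as the discriminator's example input.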
        example_input = DCGANGenerator(nz, ngf, nc).to(device)(
            torch.rand(bs, nz, 1, 1, device=device)
        )
        self.checkTrace(
            DCGANDiscriminator(nc, ndf).to(device),
            (example_input,),
            export_import=check_export_import,
        )

    def test_dcgan_models(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            self._test_dcgan_models(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_dcgan_models_cuda(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            # XXX: export_import on CUDA modules doesn't work (#11480)
            self._test_dcgan_models(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_neural_style(self, device, check_export_import=True):
        class TransformerNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(
                    128, 64, kernel_size=3, stride=1, upsample=2
                )
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(
                    64, 32, kernel_size=3, stride=1, upsample=2
                )
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super().__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(
                self, in_channels, out_channels, kernel_size, stride, upsample=None
            ):
                super().__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(
                        mode="nearest", scale_factor=upsample
                    )
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        self.checkTrace(
            TransformerNet(),
            (torch.rand(5, 3, 16, 16),),
            export_import=check_export_import,
        )

    @slowTest
    def test_neural_style(self):
        self._test_neural_style(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_neural_style_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_neural_style(self, device="cuda", check_export_import=False)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
    )
    @staticmethod
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                MnistNet().to(device).eval(),
                (torch.rand(5, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_mnist(self):
        self._test_mnist(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_mnist(self, device="cuda", check_export_import=False)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_training_leaks_no_memory_cuda(self):
        net = MnistNet().cuda()
        # MnistNet uses dropout, don't check its trace
        traced_net = torch.jit.trace(
            net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
        )

        def train(iters):
            for _ in range(iters):
                # Get some fake data
                inp = torch.randn(5, 1, 28, 28, device="cuda")
                out = traced_net(inp)

                # Here's some fake loss
                out.sum().backward()

                # Zero out grads
                traced_net.zero_grad()

        # Set it up so the params have .grad fields so they are not reported as leaks
        train(1)

        with self.assertLeaksNoCudaTensors():
            train(5)

    @staticmethod
    def _test_reinforcement_learning(self, device, test_export_import=True):
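        # Small policy network: 4 observation features -> 128 hidden units -> softmax over 2 actions.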
        class Policy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                Policy().to(device),
                (torch.rand(1, 4, device=device),),
                export_import=test_export_import,
            )

    def test_reinforcement_learning(self):
        self._test_reinforcement_learning(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_reinforcement_learning_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_reinforcement_learning(self, device="cuda", test_export_import=False)

    @staticmethod
    def _test_snli(self, device, check_export_import=True):
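        # Bottle lets a 2-D layer such as nn.Linear handle (seq_len, batch, features) input
        # by flattening the first two dims and restoring them afterwards.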
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super().forward(input)
                size = input.size()[:2]
                out = super().forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass

        class Encoder(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(
                    input_size=input_size,
                    hidden_size=config.d_hidden,
                    num_layers=config.n_layers,
                    dropout=dropout,
                    bidirectional=config.birnn,
                )

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
                return (
                    ht[-1]
                    if not self.config.birnn
                    else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
                )

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out),
                )

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

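        # Random token indices shaped (seq_len, batch); the Encoder reads dim 1 as the batch dimension.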
        premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

        self.checkTrace(
            SNLIClassifier(Config()).to(device),
            (premise, hypothesis),
            inputs_require_grads=False,
            export_import=check_export_import,
        )

    @slowTest
    def test_snli(self):
        self._test_snli(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_snli_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_snli(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_super_resolution(self, device, check_export_import=True):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super().__init__()

                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
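                # PixelShuffle rearranges the upscale_factor**2 channels into an upscale_factor x upscale_factor spatial block.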
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

        net = Net(upscale_factor=4).to(device)
        self.checkTrace(
            net,
            (torch.rand(5, 1, 32, 32, device=device),),
            export_import=check_export_import,
        )

    @slowTest
    def test_super_resolution(self):
        self._test_super_resolution(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_super_resolution_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_super_resolution(self, device="cuda", check_export_import=False)

    @suppress_warnings
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))

                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        # disabled due to a jitter issue that will be fixed by using load/store in the compiler
        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)

    @staticmethod
    def _test_vae(self, device, check_export_import=True):
        class VAE(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
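                    # Reparameterization trick: sample eps ~ N(0, I), then scale and shift
                    # so gradients flow through mu and logvar.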
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(
                VAE().to(device).eval(),
                (torch.rand(128, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_vae(self):
        self._test_vae(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_vae_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_vae(self, device="cuda", check_export_import=False)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_trace_resnet18(self):
        x = torch.ones(1, 3, 224, 224)
        m_orig = torch.jit.trace(torchvision.models.resnet18(), x)
        m_import = self.getExportImportCopy(m_orig)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = m_orig(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()

        output_import = m_import(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_script_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
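            # Listing downsample in __constants__ marks it as a constant for the script
            # compiler, so the "is not None" check in forward can be resolved when the block is compiled.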
            __constants__ = ["downsample"]

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super().__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ["layer1", "layer2", "layer3", "layer4"]

            def __init__(self, block, layers, num_classes=1000):
                super().__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(
                            m.weight, mode="fan_out", nonlinearity="relu"
                        )
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
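                # A 1x1 conv downsample matches the residual's shape when the stride or channel count changes.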
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )

                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

        resnet18_imported = self.getExportImportCopy(resnet18)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = resnet18(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()
        output_import = resnet18_imported(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
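        # fork_rng restores the RNG state on exit, so the traced run and the re-run below see identical dropout masks.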
        with torch.random.fork_rng(devices=[]):
            g, outputs, inputs = torch.jit._get_trace_graph(
                model, x, return_inputs=True
            )
        self.run_pass("cse", g)
        m = self.createFunctionFromGraph(g)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))