# xref: /aosp_15_r20/external/pytorch/test/onnx/test_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: onnx"]
2
3import unittest
4
5import pytorch_test_common
6from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
7from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
8from model_defs.mnist import MNIST
9from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
10from model_defs.squeezenet import SqueezeNet
11from model_defs.srresnet import SRResNet
12from model_defs.super_resolution import SuperResolutionNet
13from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
14from torchvision.models import shufflenet_v2_x1_0
15from torchvision.models.alexnet import alexnet
16from torchvision.models.densenet import densenet121
17from torchvision.models.googlenet import googlenet
18from torchvision.models.inception import inception_v3
19from torchvision.models.mnasnet import mnasnet1_0
20from torchvision.models.mobilenet import mobilenet_v2
21from torchvision.models.resnet import resnet50
22from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
23from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
24from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
25from verify import verify
26
27import torch
28from torch.ao import quantization
29from torch.autograd import Variable
30from torch.onnx import OperatorExportTypes
31from torch.testing._internal import common_utils
32from torch.testing._internal.common_utils import skipIfNoLapack
33
34
# CUDA availability is checked a single time, when this module is imported.
_USE_CUDA = torch.cuda.is_available()


def toC(x):
    """Return ``x.cuda()`` if CUDA was available at import time, else ``x`` itself.

    Works for anything exposing ``.cuda()`` (tensors and ``nn.Module``s alike).
    """
    if not _USE_CUDA:
        return x
    return x.cuda()


# Leading (batch) dimension used for most of the example inputs below.
BATCH_SIZE = 2
47
48
class TestModels(pytorch_test_common.ExportTestCase):
    """ONNX export tests for a collection of reference models.

    Each test builds a model and an example input. It then hands both to
    ``exportTest``, which traces the model to an ONNX graph, lints it, and
    checks the export with ``verify`` against the Caffe2 ONNX backend.
    """

    opset_version = 9  # Caffe2 doesn't support the default.
    keep_initializers_as_inputs = False
    # NOTE(review): ``keep_initializers_as_inputs`` is not passed to ``verify``
    # by ``exportTest`` — confirm whether that is intended.

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        """Trace *model* on *inputs*, lint the graph, and verify it via Caffe2.

        Args:
            model: module to export; it is switched to eval mode for the
                duration of the export and verification.
            inputs: example input(s), used both for tracing and by ``verify``.
            rtol: relative tolerance forwarded to ``verify``.
            atol: absolute tolerance forwarded to ``verify``.
            **kwargs: accepted for call-site compatibility but currently
                ignored — nothing here forwards them. For example,
                ``acceptable_error_percentage`` has no effect.
        """
        # The Caffe2 backend is only imported here, when a test actually runs.
        import caffe2.python.onnx.backend as backend

        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            torch._C._jit_pass_lint(graph)
            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def _build_calibrated_qat_resnet50(self, x, weight_fake_quant):
        """Return a ResNet50 prepared for QAT and calibrated once on *x*.

        The activation fake-quant is ``quantization.default_fake_quant``; the
        weight fake-quant is *weight_fake_quant*. Observers and fake-quant are
        enabled and one forward pass is run. ``calculate_qparams`` is then
        called on every ``FakeQuantize`` module, and observers are disabled
        before the model is returned.
        """
        model = resnet50()
        model.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=weight_fake_quant,
        )
        quantization.prepare_qat(model, inplace=True)
        model.apply(torch.ao.quantization.enable_observer)
        model.apply(torch.ao.quantization.enable_fake_quant)

        # Single forward pass with observers enabled; the output is unused.
        _ = model(x)
        for module in model.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        model.apply(torch.ao.quantization.disable_observer)
        return model

    # NOTE: many inputs below are ``torch.randn(...).fill_(1.0)``. The random
    # draw is overwritten, but it still advances the global RNG before the
    # model's weights are initialized, so it is kept as-is.

    def test_ops(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = torch.randn(BATCH_SIZE, 3)
        input_b = torch.randn(BATCH_SIZE, 3)
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = torch.randn(BATCH_SIZE, 3, 10, 12)
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = torch.randint(0, 10, (BATCH_SIZE, 3))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = torch.randint(0, 10, (BATCH_SIZE, 3))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = torch.randn(1, 3, 224, 224).fill_(1.0)
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0)
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
    def test_inception(self):
        x = torch.randn(BATCH_SIZE, 3, 299, 299)
        # NOTE(review): exportTest ignores acceptable_error_percentage (it is
        # swallowed by **kwargs), so no extra tolerance is actually applied.
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        # Previously this was built with version=1.1, so v1.0 was never tested.
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # Densenet-121 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        input = torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1)
        self.exportTest(toC(netD), toC(input))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        input = torch.empty(bsz, nz, 1, 1).normal_(0, 1)
        self.exportTest(toC(netG), toC(input))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        # Use per tensor for weight. Per channel support will come with opset 13
        qat_resnet50 = self._build_calibrated_qat_resnet50(
            x, quantization.default_fake_quant
        )
        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        qat_resnet50 = self._build_calibrated_qat_resnet50(
            x, quantization.default_per_channel_weight_fake_quant
        )
        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = torch.randn(1, 3, 4, 112, 112).fill_(1.0)
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = torch.randn(1, 3, 4, 112, 112).fill_(1.0)
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = torch.randn(1, 3, 4, 112, 112).fill_(1.0)
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
284
285
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test entry point.
    common_utils.run_tests()
288