# Owner(s): ["module: onnx"]

import unittest

import pytorch_test_common
from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
from model_defs.mnist import MNIST
from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
from model_defs.squeezenet import SqueezeNet
from model_defs.srresnet import SRResNet
from model_defs.super_resolution import SuperResolutionNet
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
from torchvision.models import shufflenet_v2_x1_0
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121
from torchvision.models.googlenet import googlenet
from torchvision.models.inception import inception_v3
from torchvision.models.mnasnet import mnasnet1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.resnet import resnet50
from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
from verify import verify

import torch
from torch.ao import quantization
from torch.autograd import Variable
from torch.onnx import OperatorExportTypes
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


if torch.cuda.is_available():

    def toC(x):
        # Move the tensor (or model) to the GPU when CUDA is available.
        return x.cuda()

else:

    def toC(x):
        # CUDA is unavailable; leave the input on the CPU.
        return x


BATCH_SIZE = 2


class TestModels(pytorch_test_common.ExportTestCase):
    opset_version = 9  # Caffe2 doesn't support the default opset version.
    keep_initializers_as_inputs = False

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        import caffe2.python.onnx.backend as backend

        # Trace the model in eval mode, lint the traced graph, and verify the
        # exported model against the Caffe2 ONNX backend.
        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            torch._C._jit_pass_lint(graph)
            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet-50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable: a sporadic single-element mismatch can occur.
    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # DenseNet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(input))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        input = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(input))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize a ResNet-50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        qat_resnet50 = resnet50()

        # Use per-tensor fake quantization for weights; per-channel support
        # requires opset 13.
        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # Run one forward pass so the observers collect statistics, then
        # freeze them before export.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize a ResNet-50 model with per-channel weight fake quantization.
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        qat_resnet50 = resnet50()

        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)


if __name__ == "__main__":
    common_utils.run_tests()