# Owner(s): ["module: onnx"]

import unittest

import pytorch_test_common
from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
from model_defs.mnist import MNIST
from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
from model_defs.squeezenet import SqueezeNet
from model_defs.srresnet import SRResNet
from model_defs.super_resolution import SuperResolutionNet
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
from torchvision.models import shufflenet_v2_x1_0
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121
from torchvision.models.googlenet import googlenet
from torchvision.models.inception import inception_v3
from torchvision.models.mnasnet import mnasnet1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.resnet import resnet50
from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
from verify import verify

import torch
from torch.ao import quantization
from torch.autograd import Variable
from torch.onnx import OperatorExportTypes
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


if torch.cuda.is_available():

    def toC(x):
        """Move ``x`` (a model or an input tensor) to CUDA via ``.cuda()``."""
        return x.cuda()

else:

    def toC(x):
        """CUDA is unavailable: return ``x`` unchanged so tests run on CPU."""
        return x


BATCH_SIZE = 2


class TestModels(pytorch_test_common.ExportTestCase):
    """Export a collection of reference models to ONNX and check them against
    the Caffe2 ONNX backend."""

    opset_version = 9  # Caffe2 doesn't support the default.
    # NOTE(review): not read by exportTest below; presumably consumed by the
    # ExportTestCase base or by subclasses -- confirm before relying on it.
    keep_initializers_as_inputs = False

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        """Trace ``model`` on ``inputs``, lint the traced graph, and compare
        the model's outputs with the Caffe2 ONNX backend via ``verify``.

        Args:
            model: module under test. It is switched to eval mode for the
                duration of the ``select_model_mode_for_export`` block.
            inputs: positional input(s) passed to the tracer and to ``verify``.
            rtol: relative tolerance forwarded to ``verify``.
            atol: absolute tolerance forwarded to ``verify``.
            **kwargs: accepted for call-site compatibility but NOT forwarded
                to ``verify``; e.g. ``acceptable_error_percentage`` passed by
                ``test_inception`` currently has no effect.
        """
        import caffe2.python.onnx.backend as backend

        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            # Trace first so that a malformed graph fails fast in the JIT
            # linter before the (slower) backend comparison runs.
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            torch._C._jit_pass_lint(graph)
            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        # Wrap the pair in an outer 1-tuple so the model receives a single
        # positional argument that is itself a tuple of two tensors.
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
        # NOTE(review): exportTest swallows **kwargs, so this
        # acceptable_error_percentage is currently ignored.
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet 1.0: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Previously this constructed version=1.1, so the 1.0 variant was
        # never exercised and 1.1 was tested twice.
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(input))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        input = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(input))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        qat_resnet50 = resnet50()

        # Use per tensor for weight. Per channel support will come with opset 13
        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # Run one forward pass with observers enabled, compute the
        # quantization parameters, then freeze observers before export.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        qat_resnet50 = resnet50()

        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # Run one forward pass with observers enabled, compute the
        # quantization parameters, then freeze observers before export.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)


if __name__ == "__main__":
    common_utils.run_tests()