# Owner(s): ["module: onnx"]

import os
import unittest

import onnx_test_common
import parameterized
import PIL
import torchvision

import torch
from torch import nn


def _get_test_image_tensor():
    """Load the bundled sample image and return it as a preprocessed batch.

    The image is read from the ``assets`` directory next to this file, then
    resized, center-cropped, converted to a tensor and normalized.

    Returns:
        The preprocessed image tensor with a leading batch dimension of
        size 1 added via ``unsqueeze(0)``.
    """
    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    img_path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
    # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # PIL opens images lazily and keeps the underlying file handle open.
    # Run the whole preprocessing pipeline inside the context manager so the
    # pixel data is read while the file is open and the handle is released
    # afterwards instead of being leaked.
    with PIL.Image.open(img_path) as input_image:
        return preprocess(input_image).unsqueeze(0)


# Due to precision error from quantization, check only that the top prediction matches.
class _TopPredictor(nn.Module):
    """Wrap a model so that its output is reduced to the top-1 index."""

    def __init__(self, base_model):
        """Store ``base_model`` as the wrapped submodule."""
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        """Run the wrapped model and return the index of its largest output.

        Only the first element of the model output (``x[0]``) is inspected.
        """
        x = self.base_model(x)
        _, topk_id = torch.topk(x[0], 1)
        return topk_id


# TODO: All torchvision quantized model tests can be written as a single parameterized
# test case after per-parameter test decoration is supported via #79979, or after they
# are all enabled, whichever comes first.
@parameterized.parameterized_class(
    ("is_script",),
    [(True,), (False,)],
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    """ONNX Runtime checks for torchvision's pretrained quantized models.

    The class is generated twice, once with ``is_script=True`` and once with
    ``is_script=False``. Every model is compared on its top-1 prediction only,
    because ``run_test`` wraps it in ``_TopPredictor`` first.
    """

    def run_test(self, model, inputs, *args, **kwargs):
        """Wrap ``model`` in ``_TopPredictor`` and defer to the base ``run_test``."""
        top1_model = _TopPredictor(model)
        return super().run_test(top1_model, inputs, *args, **kwargs)

    def _check_pretrained_quantized(self, model_factory):
        """Build a pretrained quantized model from ``model_factory`` and test it.

        The model is constructed before the test image is loaded, matching the
        order used when the steps are written inline.
        """
        quantized_model = model_factory(pretrained=True, quantize=True)
        self.run_test(quantized_model, _get_test_image_tensor())

    def test_mobilenet_v3(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.mobilenet_v3_large
        )

    @unittest.skip("quantized::cat not supported")
    def test_inception_v3(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.inception_v3
        )

    @unittest.skip("quantized::cat not supported")
    def test_googlenet(self):
        self._check_pretrained_quantized(torchvision.models.quantization.googlenet)

    @unittest.skip("quantized::cat not supported")
    def test_shufflenet_v2_x0_5(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.shufflenet_v2_x0_5
        )

    def test_resnet18(self):
        self._check_pretrained_quantized(torchvision.models.quantization.resnet18)

    def test_resnet50(self):
        self._check_pretrained_quantized(torchvision.models.quantization.resnet50)

    def test_resnext101_32x8d(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.resnext101_32x8d
        )