xref: /aosp_15_r20/external/pytorch/test/onnx/test_models_quantized_onnxruntime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3import os
4import unittest
5
6import onnx_test_common
7import parameterized
8import PIL
9import torchvision
10
11import torch
12from torch import nn
13
14
def _get_test_image_tensor():
    """Load the bundled Grace Hopper test image as a preprocessed input batch.

    Preprocessing follows the ResNet example on PyTorch Hub
    (https://pytorch.org/hub/pytorch_vision_resnet/):
    resize to 256, center-crop to 224, convert to a float tensor, then
    normalize with the listed per-channel mean/std.

    Returns:
        A tensor of shape ``(1, C, 224, 224)``: the preprocessed image with a
        leading batch dimension added by ``unsqueeze(0)``.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not guarantee
    # that ``PIL.Image`` has been loaded.
    from PIL import Image

    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    img_path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Use the image as a context manager so its file handle is closed
    # deterministically. PIL decodes lazily, so preprocessing must run while
    # the file is still open.
    with Image.open(img_path) as input_image:
        return preprocess(input_image).unsqueeze(0)
31
32
# Quantization introduces precision errors, so check only that the top prediction matches.
class _TopPredictor(nn.Module):
    """Reduce a classifier's output to the index of its highest score.

    Only the first row of the wrapped model's output is inspected, and the
    result is the 1-element index tensor produced by ``torch.topk`` with k=1.
    """

    def __init__(self, base_model):
        super().__init__()
        # The classifier whose output gets reduced to a top-1 index.
        self.base_model = base_model

    def forward(self, x):
        """Run ``base_model`` on ``x`` and return the top-1 index of row 0."""
        first_row = self.base_model(x)[0]
        _top_scores, top_ids = torch.topk(first_row, 1)
        return top_ids
43
44
# TODO: All torchvision quantized model tests can be written as a single
# parameterized test case once per-parameter test decoration is supported
# (see #79979) or once all of them are enabled, whichever comes first.
@parameterized.parameterized_class(
    ("is_script",),
    [(True,), (False,)],
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    """ONNX Runtime tests for torchvision's pre-trained quantized models.

    ``parameterized_class`` generates one variant of this class with
    ``is_script=True`` and one with ``is_script=False`` (presumably selecting
    scripting vs. tracing in the base class -- confirm in onnx_test_common).
    Each model is wrapped in ``_TopPredictor`` before being handed to the base
    ``run_test``.
    """

    def run_test(self, model, inputs, *args, **kwargs):
        """Wrap ``model`` in ``_TopPredictor`` and defer to the base ``run_test``."""
        model = _TopPredictor(model)
        return super().run_test(model, inputs, *args, **kwargs)

    def _run_quantized_model(self, builder):
        """Build a pre-trained, quantized model with ``builder`` and test it.

        Args:
            builder: a model constructor from ``torchvision.models.quantization``
                accepting ``pretrained`` and ``quantize`` keyword arguments.
        """
        model = builder(pretrained=True, quantize=True)
        self.run_test(model, _get_test_image_tensor())

    def test_mobilenet_v3(self):
        self._run_quantized_model(torchvision.models.quantization.mobilenet_v3_large)

    @unittest.skip("quantized::cat not supported")
    def test_inception_v3(self):
        self._run_quantized_model(torchvision.models.quantization.inception_v3)

    @unittest.skip("quantized::cat not supported")
    def test_googlenet(self):
        self._run_quantized_model(torchvision.models.quantization.googlenet)

    @unittest.skip("quantized::cat not supported")
    def test_shufflenet_v2_x0_5(self):
        self._run_quantized_model(torchvision.models.quantization.shufflenet_v2_x0_5)

    def test_resnet18(self):
        self._run_quantized_model(torchvision.models.quantization.resnet18)

    def test_resnet50(self):
        self._run_quantized_model(torchvision.models.quantization.resnet50)

    def test_resnext101_32x8d(self):
        self._run_quantized_model(torchvision.models.quantization.resnext101_32x8d)
98