# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/resnet.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import unittest
8
9import torch
10import torchvision
11
12from executorch.backends.xnnpack.test.tester import Quantize, Tester
13
14
class TestResNet18(unittest.TestCase):
    """XNNPACK lowering tests for torchvision's ResNet-18.

    Each test wraps a model in a ``Tester``, exports it, lowers it via
    ``to_edge_transform_and_lower``, converts it to an ExecuTorch program,
    serializes it, and finally runs it through
    ``run_method_and_compare_outputs``.
    """

    # Shared example input: a single 1x3x224x224 tensor, created once when the
    # class is defined and reused by every test.
    inputs = (torch.randn(1, 3, 224, 224),)

    # Dynamic-shape spec for the single positional input: dims 2 ("height")
    # and 3 ("width") may each vary within [224, 455].
    dynamic_shapes = (
        {
            2: torch.export.Dim("height", min=224, max=455),
            3: torch.export.Dim("width", min=224, max=455),
        },
    )

    class DynamicResNet(torch.nn.Module):
        """ResNet-18 preceded by a bilinear resize to 224x224.

        The resize lets the exported graph accept inputs whose dims 2 and 3
        vary (see ``dynamic_shapes``), while the wrapped ResNet-18 always
        receives a fixed 224x224 spatial size.
        """

        def __init__(self):
            super().__init__()
            self.model = torchvision.models.resnet18()

        def forward(self, x):
            # Normalize the variable-size input to the fixed size the
            # backbone is exercised with.
            x = torch.nn.functional.interpolate(
                x,
                size=(224, 224),
                mode="bilinear",
                align_corners=True,
                antialias=False,
            )
            return self.model(x)

    def _test_exported_resnet(self, tester):
        """Run the shared export -> lower -> serialize -> compare pipeline.

        Checks that no edge-dialect ``aten.convolution`` or ``aten.mean.dim``
        nodes remain after lowering (they should have been absorbed by the
        delegate) and that an ``executorch_call_delegate`` node is present.

        Args:
            tester: A ``Tester`` wrapping the model and example inputs,
                optionally with a quantize stage already applied.
        """
        (
            tester.export()
            .to_edge_transform_and_lower()
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                ]
            )
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_resnet18(self):
        """Lower the float ResNet-18 with static input shapes."""
        self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs))

    # Named ``test_*`` so the runner discovers it and reports it as skipped
    # with the tracking reason; with a leading underscore the skip decorator
    # was dead code and the disabled test was invisible in reports.
    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def test_qs8_resnet18(self):
        """Lower ResNet-18 quantized with the default (calibrating) stage."""
        quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize()
        self._test_exported_resnet(quantized_tester)

    # Backward-compatible alias for the previous name (not collected by the
    # runner because it does not start with ``test``).
    _test_qs8_resnet18 = test_qs8_resnet18

    # TODO(T187799178): Delete once the calibrated test above is re-enabled.
    def test_qs8_resnet18_no_calibration(self):
        """Lower ResNet-18 quantized without calibration."""
        quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize(
            Quantize(calibrate=False)
        )
        self._test_exported_resnet(quantized_tester)

    def test_fp32_resnet18_dynamic(self):
        """Lower the float resize+ResNet-18 wrapper with dynamic H/W."""
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes)
        )

    # Named ``test_*`` for the same reason as ``test_qs8_resnet18``.
    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def test_qs8_resnet18_dynamic(self):
        """Lower the quantized (calibrated) wrapper with dynamic H/W."""
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize()
        )

    # Backward-compatible alias for the previous name.
    _test_qs8_resnet18_dynamic = test_qs8_resnet18_dynamic

    # TODO(T187799178): Delete once the calibrated dynamic test above is
    # re-enabled.
    def test_qs8_resnet18_dynamic_no_calibration(self):
        """Lower the uncalibrated quantized wrapper with dynamic H/W."""
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize(
                Quantize(calibrate=False)
            )
        )
87        )
88