xref: /aosp_15_r20/external/pytorch/test/onnx/test_verification.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3import contextlib
4import io
5import tempfile
6import unittest
7
8import numpy as np
9
10import onnx
11import parameterized
12import pytorch_test_common
13from packaging import version
14
15import torch
16from torch.onnx import _constants, _experimental, verification
17from torch.testing._internal import common_utils
18
19
class TestVerification(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.check_export_model_diff`` and output comparison."""

    @staticmethod
    def _make_mismatching_outputs():
        """Return an (onnx-side, torch-side) output pair differing in one of four elements."""
        onnx_side = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        torch_side = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        return onnx_side, torch_side

    @staticmethod
    def _make_options(acceptable_error_percentage):
        """Build tight-tolerance options that vary only in the error allowance."""
        return verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=acceptable_error_percentage,
        )

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # ``y.data`` is exported as a constant, so the exported model
                # produces wrong output once ``y`` changes.
                return x + y.data

        input_groups = [((torch.randn(2, 3), torch.randn(2, 3)), {}) for _ in range(2)]

        diff_report = verification.check_export_model_diff(
            UnexportableModel(), input_groups
        )
        expected_pattern = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:"
        )
        self.assertRegex(diff_report, expected_pattern)

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # The iteration count depends on x's leading dimension, so the
                # graphs exported for different x sizes are expected to differ
                # even though axis 0 of x is declared dynamic.
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        # First group has 2 rows in x, second has 4; y keeps shape (2, 3).
        input_groups = [
            ((torch.randn(rows, 3), torch.randn(2, 3)), {}) for rows in (2, 4)
        ]
        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )

        diff_report = verification.check_export_model_diff(
            UnexportableModel(), input_groups, export_options
        )
        expected_pattern = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*"
        )
        self.assertRegex(diff_report, expected_pattern)

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        input_groups = [((torch.randn(2, 3), torch.randn(2, 3)), {}) for _ in range(2)]

        self.assertEqual(
            verification.check_export_model_diff(SupportedModel(), input_groups), ""
        )

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        ort_outs, pytorch_outs = self._make_mismatching_outputs()
        # One element out of four (0.25) mismatches, which is below the 0.3
        # allowance, so the comparison must not raise.
        verification._compare_onnx_pytorch_outputs(
            ort_outs, pytorch_outs, self._make_options(0.3)
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        ort_outs, pytorch_outs = self._make_mismatching_outputs()
        # With no allowance, the single mismatching element must be reported.
        strict_options = self._make_options(None)
        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                ort_outs, pytorch_outs, strict_options
            )
126
127
@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    """Checks that ``verification.verify`` catches a deliberately wrong export.

    ``setUp`` replaces the ``aten::add`` symbolic with one that ignores the
    addend, so any model using ``+`` exports incorrectly.
    """

    opset_version: int

    def setUp(self):
        super().setUp()

        # ``self`` here is the op's first input value, not the TestCase.
        def incorrect_add_symbolic_function(g, self, other, alpha):
            # Wrong on purpose: drop ``other``/``alpha`` and return the input.
            return self

        self.opset_version = _constants.ONNX_DEFAULT_OPSET
        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            incorrect_add_symbolic_function,
            opset_version=self.opset_version,
        )
        # The registration outlives this test unless explicitly undone.
        # Registering the undo as a cleanup (instead of doing it after
        # ``super().tearDown()``) guarantees it runs even if the base-class
        # tearDown raises, so the broken symbolic cannot leak into other tests.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::add",
            opset_version=self.opset_version,
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        """``verify`` must raise because the exported ``x + 1`` returns ``x``."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            verification.verify(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=onnx_backend),
            )
180
181
@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.REFERENCE},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls, idx, input_dicts: (
        f"{cls.__name__}_{input_dicts['onnx_backend'].name}"
    ),
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    """Exercises ``verification.find_mismatch`` on a deliberately wrong export.

    ``setUp`` replaces the ``aten::relu`` symbolic with ``x + 1``. It then runs
    ``find_mismatch`` on a Linear/ReLU stack and stores the resulting
    ``GraphInfo`` in ``self.graph_info`` for the individual tests.
    """

    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        # ``self`` here is the op's first input value, not the TestCase.
        def incorrect_relu_symbolic_function(g, self):
            # Wrong on purpose: export relu(x) as x + 1.
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )
        # Schedule the undo immediately. If ``find_mismatch`` below raises,
        # unittest skips tearDown but still runs cleanups, so the broken
        # symbolic cannot leak into later tests. The opset value is bound
        # here because tearDown deletes ``self.opset_version``.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::relu",
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        super().tearDown()
        # Remove per-test state so the TestCase instance no longer references
        # it. Unregistering the symbolic is handled by the cleanup from setUp.
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertGreater(len(mismatch_leaves), 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            # assertRegex uses re.search, so no surrounding wildcards are
            # needed. Accept either path separator and match a literal ".py".
            self.assertRegex(
                f.getvalue(),
                r"aten::relu.*[\\/]torch[\\/]nn[\\/]functional\.py:[0-9]+",
            )

    def test_find_all_mismatch_operators(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        # One mismatching leaf per ReLU in the model.
        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertGreater(len(mismatch_leaves), 0)
        leaf_info = mismatch_leaves[0]
        options = verification.VerificationOptions(backend=self.onnx_backend)
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            # Only validating the exported repro is expected to fail.
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                verification.OnnxTestCaseRepro(repro_dir).validate(options)
295
296
if __name__ == "__main__":
    # Running this file directly hands control to PyTorch's shared test runner.
    common_utils.run_tests()
299