xref: /aosp_15_r20/external/pytorch/test/onnx/test_verification.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: onnx"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport io
5*da0073e9SAndroid Build Coastguard Workerimport tempfile
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport numpy as np
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport onnx
11*da0073e9SAndroid Build Coastguard Workerimport parameterized
12*da0073e9SAndroid Build Coastguard Workerimport pytorch_test_common
13*da0073e9SAndroid Build Coastguard Workerfrom packaging import version
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerimport torch
16*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx import _constants, _experimental, verification
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import common_utils
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
class TestVerification(pytorch_test_common.ExportTestCase):
    """Tests for ``check_export_model_diff`` and ONNX/PyTorch output comparison."""

    @staticmethod
    def _outputs_with_one_mismatch():
        # Backend and PyTorch outputs that differ in exactly one of their
        # four elements (4.0 vs 1.0 in the last position).
        backend_outputs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        torch_outputs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        return backend_outputs, torch_outputs

    @staticmethod
    def _comparison_options(acceptable_error_percentage):
        # Shared comparison settings; only the tolerated error percentage
        # differs between the two comparison tests.
        return verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=acceptable_error_percentage,
        )

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # `y.data` is exported as a constant, so the exported model
                # produces wrong output once the inputs change.
                return x + y.data

        # Two input groups of identical shape but different random values.
        first_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        second_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})

        diff_report = verification.check_export_model_diff(
            UnexportableModel(), [first_group, second_group]
        )

        expected_pattern = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:"
        )
        self.assertRegex(diff_report, expected_pattern)

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # The trip count depends on the size of x's first dimension.
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        # The second group changes x's leading dimension from 2 to 4.
        first_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        second_group = ((torch.randn(4, 3), torch.randn(2, 3)), {})

        options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        diff_report = verification.check_export_model_diff(
            UnexportableModel(), [first_group, second_group], options
        )

        expected_pattern = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*"
        )
        self.assertRegex(diff_report, expected_pattern)

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        first_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        second_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})

        diff_report = verification.check_export_model_diff(
            SupportedModel(), [first_group, second_group]
        )

        # No divergence between the exports means an empty report.
        self.assertEqual(diff_report, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        backend_outputs, torch_outputs = self._outputs_with_one_mismatch()
        tolerant_options = self._comparison_options(0.3)

        # Must return without raising.
        verification._compare_onnx_pytorch_outputs(
            backend_outputs,
            torch_outputs,
            tolerant_options,
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        backend_outputs, torch_outputs = self._outputs_with_one_mismatch()
        strict_options = self._comparison_options(None)

        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                backend_outputs,
                torch_outputs,
                strict_options,
            )
127*da0073e9SAndroid Build Coastguard Worker
@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    """Checks that ``verification.verify`` flags an export with a broken ``aten::add``."""

    opset_version: int

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        def add_returning_first_operand(g, self, other, alpha):
            # Deliberately wrong: drops `other` and returns the first operand.
            return self

        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            add_returning_first_operand,
            opset_version=self.opset_version,
        )

    def tearDown(self):
        super().tearDown()
        # Remove the override installed by setUp().
        torch.onnx.unregister_custom_op_symbolic(
            "aten::add", opset_version=self.opset_version
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        # With the override in place the exported add no longer adds 1,
        # so verification against PyTorch has to fail.
        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            module = Model()
            example_inputs = (torch.randn(2, 3),)
            backend_options = verification.VerificationOptions(backend=onnx_backend)
            verification.verify(
                module,
                example_inputs,
                opset_version=self.opset_version,
                options=backend_options,
            )
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker
@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.ONNX},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls, idx, input_dicts: (
        f"{cls.__name__}_{input_dicts['onnx_backend'].name}"
    ),
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.find_mismatch`` using a deliberately wrong ReLU export.

    ``setUp`` overrides the ``aten::relu`` symbolic with one that emits
    ``x + 1`` and stores the resulting ``GraphInfo`` in ``self.graph_info``.
    """

    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        """Install the wrong ``aten::relu`` symbolic and run ``find_mismatch``."""
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        def incorrect_relu_symbolic_function(g, self):
            # Emits `self + 1` instead of a ReLU.
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )
        # Unregister through a cleanup instead of tearDown(): unittest skips
        # tearDown() when setUp() raises but still runs cleanups, so a failure
        # in find_mismatch() below cannot leak the override into other tests.
        # The opset version is bound here, so the later delattr() in
        # tearDown() does not affect the cleanup call.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::relu",
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Two ReLU layers, each exported through the wrong symbolic.
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        """Drop per-test attributes; the symbolic override is removed by a cleanup."""
        super().tearDown()
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        # Capture what pretty_print_tree() writes to stdout and compare it
        # with the stored expected output.
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        # assertGreater reports the actual length on failure.
        self.assertGreater(len(mismatch_leaves), 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            # assertRegex searches the text, so no leading/trailing wildcards
            # are needed; the dot in "functional.py" is escaped so it only
            # matches a literal dot.
            self.assertRegex(
                f.getvalue(),
                r"aten::relu.*/torch/nn/functional\.py:[0-9]+",
            )

    def test_find_all_mismatch_operators(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        # One mismatching leaf per ReLU layer in the model built by setUp().
        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        class Model(torch.nn.Module):
            def forward(self, x):
                # Contains no ReLU, so the overridden symbolic is never used.
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertGreater(len(mismatch_leaves), 0)
        leaf_info = mismatch_leaves[0]
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            # Validating the exported repro must hit the same mismatch.
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                options = verification.VerificationOptions(backend=self.onnx_backend)
                verification.OnnxTestCaseRepro(repro_dir).validate(options)
296*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test runner when the
    # file is executed directly.
    common_utils.run_tests()
299