# Owner(s): ["module: onnx"]

import contextlib
import io
import tempfile
import unittest

import numpy as np

import onnx
import parameterized
import pytorch_test_common
from packaging import version

import torch
from torch.onnx import _constants, _experimental, verification
from torch.testing._internal import common_utils


class TestVerification(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.check_export_model_diff`` and output comparison."""

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # ``y.data`` will be exported as a constant,
                # leading to wrong model output under different inputs.
                return x + y.data

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # The trip count depends on x.size(0); the Python loop is
                # expected to be unrolled at export time, so the two input
                # groups below (x.size(0) == 2 vs 4) yield different graphs.
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(4, 3), torch.randn(2, 3)), {}),
        ]

        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups, export_options
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*",
        )

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            SupportedModel(), test_input_groups
        )
        self.assertEqual(results, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        # One of the four elements differs (4.0 vs 1.0), i.e. 25% of them;
        # this is below the 0.3 threshold, so no error is expected.
        # NOTE(review): assumes acceptable_error_percentage is a fraction in [0, 1].
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=0.3,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            pytorch_outs,
            options,
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        # Same mismatching data as above, but with no tolerance for mismatching
        # elements the comparison must fail.
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=None,
        )
        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                ort_outs,
                pytorch_outs,
                options,
            )


@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    """Checks that ``verification.verify`` detects a deliberately wrong export.

    ``setUp`` overrides the ``aten::add`` symbolic with one that returns its
    first operand unchanged, so ``x + 1`` is exported as ``x``.
    """

    opset_version: int

    def setUp(self):
        super().setUp()

        # Deliberately wrong symbolic: ignores ``other`` and ``alpha`` and
        # returns the first operand.
        def incorrect_add_symbolic_function(g, self, other, alpha):
            return self

        self.opset_version = _constants.ONNX_DEFAULT_OPSET
        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            incorrect_add_symbolic_function,
            opset_version=self.opset_version,
        )
        # The registration outlives this test unless explicitly undone.
        # Undo it via addCleanup, which runs even if the rest of setUp or
        # tearDown raises, so the broken symbolic never leaks into later tests.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::add",
            opset_version=self.opset_version,
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            verification.verify(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=onnx_backend),
            )


@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.REFERENCE},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls, idx, input_dicts: (
        f"{cls.__name__}_{input_dicts['onnx_backend'].name}"
    ),
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.find_mismatch`` on a deliberately wrong export.

    ``setUp`` overrides the ``aten::relu`` symbolic so that ``relu(x)`` is
    exported as ``x + 1``. It then runs ``find_mismatch`` on a model with two
    ReLU layers and stores the resulting ``GraphInfo`` in ``self.graph_info``.
    """

    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        # Deliberately wrong symbolic: exports relu(x) as x + 1.
        def incorrect_relu_symbolic_function(g, self):
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )
        # Register the undo step before find_mismatch runs. If anything below
        # raises, unittest skips tearDown but still runs cleanups, so the
        # broken symbolic is always unregistered.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::relu",
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        super().tearDown()
        # Drop the per-test attributes set in setUp. The symbolic itself is
        # unregistered by the cleanup added in setUp.
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertGreater(len(mismatch_leaves), 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            # assertRegex uses re.search, so no leading/trailing wildcards are
            # needed. Accept either path separator so the check is not tied to
            # POSIX-style paths.
            self.assertRegex(
                f.getvalue(),
                r"aten::relu.*[\\/]torch[\\/]nn[\\/]functional\.py:[0-9]+",
            )

    def test_find_all_mismatch_operators(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        # One mismatching leaf per ReLU layer in the model built in setUp.
        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertGreater(len(mismatch_leaves), 0)
        leaf_info = mismatch_leaves[0]
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            # Validation must happen while temp_dir still exists. Only the
            # validate call is expected to raise.
            options = verification.VerificationOptions(backend=self.onnx_backend)
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                verification.OnnxTestCaseRepro(repro_dir).validate(options)


if __name__ == "__main__":
    common_utils.run_tests()