1# mypy: allow-untyped-defs 2"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model. 3 4ONNX Runtime is required, and is used as the ONNX backend for export verification. 5""" 6 7from __future__ import annotations 8 9import contextlib 10import copy 11import dataclasses 12import datetime 13import difflib 14import enum 15import functools 16import io 17import itertools 18import os 19import tempfile 20import warnings 21from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union 22 23import numpy as np 24 25import torch 26import torch._C._onnx as _C_onnx 27from torch import _C 28from torch.onnx import _constants, _experimental, _exporter_states, utils 29from torch.onnx._globals import GLOBALS 30from torch.onnx._internal import onnx_proto_utils 31from torch.types import Number 32 33 34_ORT_PROVIDERS = ("CPUExecutionProvider",) 35 36_NumericType = Union[Number, torch.Tensor, np.ndarray] 37_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] 38_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]] 39_InputKwargsType = Mapping[str, Any] 40_OutputsType = Union[Sequence[_NumericType], Sequence] 41 42 43class OnnxBackend(enum.Enum): 44 """Enum class for ONNX backend used for export verification.""" 45 46 REFERENCE = "ONNXReferenceEvaluator" 47 ONNX_RUNTIME_CPU = "CPUExecutionProvider" 48 ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" 49 50 51@dataclasses.dataclass 52class VerificationOptions: 53 """Options for ONNX export verification. 54 55 Attributes: 56 flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of 57 Tensors for ONNX. Set this to False if nested structures are to be preserved 58 for ONNX, which is usually the case with exporting ScriptModules. Default True. 59 ignore_none: Whether to ignore None type in torch output, which is usually the 60 case with tracing. Set this to False, if torch output should keep None type, 61 which is usually the case with exporting ScriptModules. Default to True. 62 check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs 63 are exactly the same. Set this to False to allow output shape broadcasting. 64 Default to True. 65 check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs 66 are consistent. Default to True. 67 backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. 68 rtol: relative tolerance in comparison between ONNX and PyTorch outputs. 69 atol: absolute tolerance in comparison between ONNX and PyTorch outputs. 70 remained_onnx_input_idx: If provided, only the specified inputs will be passed 71 to the ONNX model. Supply a list when there are unused inputs in the model. 72 Since unused inputs will be removed in the exported ONNX model, supplying 73 all inputs will cause an error on unexpected inputs. This parameter tells 74 the verifier which inputs to pass into the ONNX model. 75 acceptable_error_percentage: acceptable percentage of element mismatches in comparison. 76 It should be a float of value between 0.0 and 1.0. 77 """ 78 79 flatten: bool = True 80 ignore_none: bool = True 81 check_shape: bool = True 82 check_dtype: bool = True 83 backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU 84 rtol: float = 1e-3 85 atol: float = 1e-7 86 remained_onnx_input_idx: Sequence[int] | None = None 87 acceptable_error_percentage: float | None = None 88 89 90def _flatten_tuples(elem): 91 flattened = [] 92 for t in elem: 93 if isinstance(t, tuple): 94 flattened.extend(_flatten_tuples(t)) 95 else: 96 flattened.append(t) 97 return flattened 98 99 100# TODO(justinchuby): Add type checking by narrowing down the return type when input is None 101def _to_numpy(elem) -> list | np.ndarray: 102 if isinstance(elem, torch.Tensor): 103 if elem.requires_grad: 104 return elem.detach().cpu().numpy() 105 else: 106 return elem.cpu().numpy() 107 elif isinstance(elem, (list, tuple)): 108 return [_to_numpy(inp) for inp in elem] 109 elif isinstance(elem, (bool, int, float)): 110 return np.array(elem) 111 elif isinstance(elem, dict): 112 flattened = [] 113 for k in elem: 114 flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) 115 return flattened 116 return elem 117 118 119def _inline_flatten_list(inputs, res_list) -> list: 120 for i in inputs: 121 res_list.append(i) if not isinstance( 122 i, (list, tuple) 123 ) else _inline_flatten_list(i, res_list) 124 return res_list 125 126 127def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: 128 value_unpacked = [] 129 for value in values: 130 value_unpacked.extend( 131 utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) 132 ) 133 return [_to_numpy(v) for v in value_unpacked] 134 135 136def _run_onnx(onnx_session, inputs) -> _OutputsType: 137 kw_inputs = {} 138 if inputs and isinstance(inputs[-1], dict): 139 kw_inputs = inputs[-1] 140 inputs = inputs[:-1] 141 inputs = _unpack_to_numpy(_flatten_tuples(inputs)) 142 ort_inputs = {} 143 for input_name, input in kw_inputs.items(): 144 ort_inputs[input_name] = _to_numpy(input) 145 inputs = _to_numpy(inputs) 146 if hasattr(onnx_session, "get_inputs"): 147 # onnxruntime.InferenceSession 148 input_names = [i.name for i in onnx_session.get_inputs()] 149 elif hasattr(onnx_session, "input_names"): 150 # onnx.reference.ReferenceEvaluator 151 input_names = onnx_session.input_names 152 else: 153 raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") 154 155 for i, input in enumerate(inputs): 156 if i == len(input_names) or input_names[i] in ort_inputs: 157 raise ValueError( 158 f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " 159 f"input names: {input_names}." 160 ) 161 ort_inputs[input_names[i]] = input 162 onnx_outs = onnx_session.run(None, ort_inputs) 163 return onnx_outs 164 165 166def _ort_session( 167 model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS 168): 169 try: 170 import onnxruntime # type: ignore[import] 171 except ImportError as e: 172 raise ImportError("onnxruntime is required for export verification.") from e 173 174 if ort_providers is None: 175 ort_providers = _ORT_PROVIDERS 176 177 session_options = onnxruntime.SessionOptions() 178 # suppress ort warnings. 179 # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. 180 session_options.log_severity_level = 3 181 ort_session = onnxruntime.InferenceSession( 182 model if isinstance(model, str) else model.getvalue(), 183 session_options, 184 providers=ort_providers, 185 ) 186 return ort_session 187 188 189def _onnx_reference_evaluator_session(model: str | io.BytesIO): 190 try: 191 import onnx 192 from onnx import reference as onnx_reference # type: ignore[attr-defined] 193 except ImportError as exc: 194 raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc 195 196 proto = ( 197 onnx.load(model) # type: ignore[attr-defined] 198 if isinstance(model, str) 199 else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] 200 ) 201 onnx_session = onnx_reference.ReferenceEvaluator(proto) 202 return onnx_session 203 204 205def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): 206 if backend == OnnxBackend.REFERENCE: 207 onnx_session = _onnx_reference_evaluator_session(model) 208 elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: 209 onnx_session = _ort_session(model, (backend.value,)) 210 else: 211 raise ValueError(f"Unsupported backend: {backend}") 212 return onnx_session 213 214 215def _compare_onnx_pytorch_outputs_in_np( 216 onnx_outs: _OutputsType, 217 pt_outs: _OutputsType, 218 options: VerificationOptions, 219): 220 assert ( 221 len(onnx_outs) == len(pt_outs) 222 ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" 223 acceptable_error_percentage = options.acceptable_error_percentage 224 if acceptable_error_percentage and ( 225 acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 226 ): 227 raise ValueError( 228 "If set, acceptable_error_percentage should be between 0.0 and 1.0" 229 ) 230 231 for ort_out, pt_out in zip(onnx_outs, pt_outs): 232 try: 233 # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. 234 if not options.check_shape: 235 # Allow different but broadcastable output shapes. 236 ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) 237 torch.testing.assert_close( 238 ort_out, 239 pt_out, 240 rtol=options.rtol, 241 atol=options.atol, 242 check_dtype=options.check_dtype, 243 equal_nan=True, 244 ) 245 except AssertionError as e: 246 if acceptable_error_percentage: 247 error_percentage = 1 - np.sum( 248 np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) 249 ) / np.prod(ort_out.shape) 250 if error_percentage <= acceptable_error_percentage: 251 warnings.warn( 252 f"Suppressed AssertionError:\n{e}.\n" 253 f"Error percentage {error_percentage} " 254 f"within acceptable range {acceptable_error_percentage}." 255 ) 256 continue 257 if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: 258 warnings.warn("ONNX output is quantized") 259 if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: 260 warnings.warn("PyTorch output is quantized") 261 raise 262 263 264def _compare_onnx_pytorch_outputs( 265 onnx_outs: _OutputsType, 266 pt_outs: Any, 267 options: VerificationOptions, 268): 269 """ 270 Compare ONNX and PyTorch outputs. 271 272 Args: 273 onnx_outs: outputs from ONNX backend. 274 pt_outs: outputs from PyTorch. 275 options: options for verification. 276 277 Raises: 278 AssertionError: if outputs from ONNX model and PyTorch model are not 279 equal up to specified precision. 280 ValueError: if arguments provided are invalid. 281 """ 282 if options.ignore_none: 283 # torch.jit._flatten filters None type 284 pt_outs, _ = torch.jit._flatten(pt_outs) 285 else: 286 pt_outs = _inline_flatten_list([pt_outs], []) 287 pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) 288 onnx_outs = _inline_flatten_list(onnx_outs, []) 289 _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) 290 291 292def _prepare_input_for_pytorch(args, kwargs): 293 """Prepare input for PyTorch model execution. 294 295 Any future changes/formatting to the input before dispatching to the PyTorch 296 model should be made in this function. 297 298 Args: 299 args: positional arguments for PyTorch model forward method. 300 kwargs: keyword arguments for PyTorch model forward method. 301 302 Returns: 303 args: positional arguments for PyTorch model forward method. 304 kwargs: keyword arguments for PyTorch model forward method. 305 """ 306 if isinstance(args, (torch.Tensor, dict)): 307 args = (args,) 308 # In-place operators will update input tensor data as well. 309 # Thus inputs are replicated before every forward call. 310 args = copy.deepcopy(args) 311 if kwargs: 312 kwargs = copy.deepcopy(kwargs) 313 else: 314 kwargs = {} 315 return args, kwargs 316 317 318def _prepare_input_for_export(args, kwargs): 319 """Prepare input for ONNX model export. 320 321 Any future changes/formatting to the input before dispatching to the 322 :func:`torch.onnx.export` api should be made in this function. 323 324 Args: 325 args: positional arguments for PyTorch model forward method. 326 kwargs: keyword arguments for PyTorch model forward method. 327 328 Returns: 329 onnx_inputs: positional arguments for ONNX model export, as `args` in 330 :func:`torch.onnx.export`. 331 """ 332 args, kwargs = _prepare_input_for_pytorch(args, kwargs) 333 if not kwargs and len(args) > 0 and isinstance(args[-1], dict): 334 onnx_inputs = args + ({},) 335 elif kwargs: 336 onnx_inputs = args + (kwargs,) 337 else: 338 onnx_inputs = args 339 return onnx_inputs 340 341 342def _prepare_input_for_onnx( 343 args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool 344): 345 """Prepare input for ONNX model execution in ONNX backend. 346 347 Any future changes/formatting to the input before dispatching to the ONNX backend 348 run should be made in this function. 349 350 Args: 351 args: positional arguments for PyTorch model forward method. 352 kwargs: keyword arguments for PyTorch model forward method. 353 remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. 354 flatten: whether to flatten the input before dispatching to the ONNX model execution. 355 356 Returns: 357 onnx_inputs: positional arguments for ONNX model execution in ONNX backend. 358 """ 359 onnx_inputs = _prepare_input_for_export(args, kwargs) 360 if flatten: 361 onnx_inputs, _ = torch.jit._flatten(onnx_inputs) 362 elif onnx_inputs and onnx_inputs[-1] == {}: 363 # Handle empty kwargs (normally removed by flatten). 364 onnx_inputs = onnx_inputs[:-1] 365 if remained_onnx_input_idx is not None: 366 return [onnx_inputs[i] for i in remained_onnx_input_idx] 367 else: 368 return onnx_inputs 369 370 371def _try_clone_model(model): 372 """Used for preserving original model in case forward mutates model states.""" 373 try: 374 return copy.deepcopy(model) 375 except Exception: 376 warnings.warn( 377 "Failed to clone model. Model state might be mutated during verification." 378 ) 379 return model 380 381 382def _compare_onnx_pytorch_model( 383 pt_model: _ModelType, 384 onnx_model_f: str | io.BytesIO, 385 input_args: _InputArgsType, 386 input_kwargs: _InputKwargsType | None, 387 additional_test_inputs: Sequence[_InputArgsType] | None, 388 options: VerificationOptions, 389): 390 """Compare outputs from ONNX model runs with outputs from PyTorch model runs. 391 392 Args: 393 pt_model: PyTorch model. 394 onnx_model_f: ONNX model file path or file-like object. 395 input_args: positional arguments for PyTorch model forward method. 396 input_kwargs: keyword arguments for PyTorch model forward method. 397 additional_test_inputs: additional positional arguments for PyTorch model 398 forward method. 399 options: options for verification. 400 401 Raises: 402 AssertionError: if outputs from ONNX model and PyTorch model are not 403 equal up to specified precision. 404 """ 405 onnx_session = _onnx_backend_session(onnx_model_f, options.backend) 406 407 def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): 408 pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) 409 # TODO: remove this and treat mutating model separately. See #77679 410 pt_model_copy = _try_clone_model(pt_model) 411 pt_outs = pt_model_copy(*pt_args, **pt_kwargs) 412 413 onnx_inputs = _prepare_input_for_onnx( 414 input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten 415 ) 416 417 onnx_outs = _run_onnx(onnx_session, onnx_inputs) 418 419 _compare_onnx_pytorch_outputs( 420 onnx_outs=onnx_outs, 421 pt_outs=pt_outs, 422 options=options, 423 ) 424 425 compare_onnx_pytorch_model_with_input(input_args, input_kwargs) 426 427 if additional_test_inputs: 428 for test_input_args in additional_test_inputs: 429 compare_onnx_pytorch_model_with_input(test_input_args, {}) 430 431 432class _GraphDiff: 433 """A class to represent the difference between two graphs.""" 434 435 def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): 436 """Construct a _GraphDiff object. 437 438 Args: 439 graph_a (_C.Graph): First graph to compare. 440 graph_b (_C.Graph): Second graph to compare. 441 """ 442 self.graph_a = graph_a 443 self.graph_b = graph_b 444 445 def __str__(self): 446 """See function :func:`diff_report`.""" 447 return self.diff_report() 448 449 def _indent(self, lines: str) -> str: 450 return "\n".join(["\t" + line for line in lines.splitlines()]) 451 452 def diff_report(self) -> str: 453 """Return a string representation of the graph difference. 454 455 The report shows the first pair of nodes that diverges. It also shows the source 456 location of the pair of nodes. 457 458 Returns: 459 graph_diff_report (str): A string representation of the graph difference. 460 """ 461 graph_a = self.graph_a 462 graph_b = self.graph_b 463 464 graph_a_str = str(graph_a) 465 graph_b_str = str(graph_b) 466 467 if graph_a_str == graph_b_str: 468 return "" 469 470 graph_diff = difflib.ndiff( 471 graph_a_str.splitlines(True), graph_b_str.splitlines(True) 472 ) 473 graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] 474 475 for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): 476 if str(node_a) != str(node_b): 477 graph_diff_report.append("First diverging operator:") 478 node_diff = difflib.ndiff( 479 str(node_a).splitlines(True), str(node_b).splitlines(True) 480 ) 481 source_printout = ["node diff:", self._indent("".join(node_diff))] 482 483 stack_a = node_a.sourceRange() if node_a else None 484 if stack_a: 485 source_printout.extend( 486 ["Former source location:", self._indent(str(stack_a))] 487 ) 488 stack_b = node_b.sourceRange() if node_b else None 489 if stack_b: 490 source_printout.extend( 491 ["Latter source location:", self._indent(str(stack_b))] 492 ) 493 494 graph_diff_report.extend(source_printout) 495 496 break 497 498 return "\n".join(graph_diff_report) 499 500 501def _check_graph_diff( 502 model: torch.nn.Module | torch.jit.ScriptModule, 503 test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], 504 export_options: _experimental.ExportOptions, 505 model_to_graph_func: Callable[ 506 [ 507 torch.nn.Module, 508 tuple[Any, ...], 509 Mapping[str, Any], 510 _experimental.ExportOptions, 511 ], 512 _C.Graph, 513 ], 514) -> str: 515 """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. 516 517 Args: 518 model: See :func:`check_export_model_diff`. 519 test_input_groups: See :func:`check_export_model_diff`. 520 export_options: See :func:`check_export_model_diff`. 521 model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. 522 523 Returns: 524 graph_diff_report (str): A string representation of the graph difference. 525 """ 526 if len(test_input_groups) < 2: 527 raise ValueError("Need at least two groups of test inputs to compare.") 528 529 ref_jit_graph = None 530 for args, kwargs in test_input_groups: 531 jit_graph = model_to_graph_func(model, args, kwargs, export_options) 532 if ref_jit_graph is None: 533 ref_jit_graph = jit_graph 534 continue 535 536 graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() 537 if graph_diff_report: 538 return graph_diff_report 539 return "" 540 541 542def _traced_graph_from_model( 543 model: torch.nn.Module | torch.jit.ScriptModule, 544 args: tuple[Any, ...], 545 kwargs: Mapping[str, Any], 546 export_options: _experimental.ExportOptions, 547) -> _C.Graph: 548 """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. 549 550 Args: 551 model: See :func:`check_export_model_diff`. 552 args: See :func:`check_export_model_diff`. 553 kwargs: See :func:`check_export_model_diff`. 554 export_options: See :func:`check_export_model_diff`. 555 556 Returns: 557 jit_graph (_C.Graph): A traced JIT graph. 558 """ 559 training = export_options.training 560 verbose = export_options.verbose 561 562 with utils.exporter_context(model, training, verbose): 563 export_inputs = _prepare_input_for_export(args, kwargs) 564 model = utils._pre_trace_quant_model(model, export_inputs) 565 jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) 566 return jit_graph 567 568 569def _onnx_graph_from_model( 570 model: torch.nn.Module | torch.jit.ScriptModule, 571 args: tuple[Any, ...], 572 kwargs: Mapping[str, Any], 573 export_options: _experimental.ExportOptions, 574) -> _C.Graph: 575 """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. 576 577 Args: 578 model: See :func:`check_export_model_diff`. 579 args: See :func:`check_export_model_diff`. 580 kwargs: See :func:`check_export_model_diff`. 581 export_options: See :func:`check_export_model_diff`. 582 583 Returns: 584 onnx_graph (_C.Graph): An ONNX JIT graph. 585 """ 586 # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 587 opset_version = export_options.opset_version 588 operator_export_type = export_options.operator_export_type 589 export_modules_as_functions = export_options.export_modules_as_functions 590 training = export_options.training 591 verbose = export_options.verbose 592 dynamic_axes = export_options.dynamic_axes 593 input_names = export_options.input_names 594 output_names = export_options.output_names 595 596 if opset_version is None: 597 opset_version = _constants.ONNX_DEFAULT_OPSET 598 599 utils._setup_trace_module_map(model, export_modules_as_functions) 600 601 if not operator_export_type: 602 operator_export_type = _C_onnx.OperatorExportTypes.ONNX 603 604 GLOBALS.export_onnx_opset_version = opset_version 605 GLOBALS.operator_export_type = operator_export_type 606 607 with utils.exporter_context(model, training, verbose): 608 do_constant_folding = utils._decide_constant_folding( 609 export_options.do_constant_folding, operator_export_type, training 610 ) 611 612 if dynamic_axes is None: 613 dynamic_axes = {} 614 utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) 615 616 export_inputs = _prepare_input_for_export(args, kwargs) 617 export_inputs = utils._decide_input_format(model, export_inputs) 618 onnx_graph, _, _ = utils._model_to_graph( 619 model, 620 export_inputs, 621 verbose, 622 input_names, 623 output_names, 624 operator_export_type, 625 do_constant_folding, 626 training=training, 627 dynamic_axes=dynamic_axes, 628 ) 629 630 return onnx_graph 631 632 633def _onnx_graph_from_aten_graph( 634 graph: torch.Graph, 635 export_options: _experimental.ExportOptions, 636 params_dict: dict[str, Any] | None = None, 637) -> tuple[torch.Graph, dict[str, Any]]: 638 if params_dict is None: 639 params_dict = {} 640 operator_export_type = export_options.operator_export_type 641 dynamic_axes = export_options.dynamic_axes or {} 642 input_names = export_options.input_names 643 training = export_options.training 644 do_constant_folding = export_options.do_constant_folding 645 opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET 646 647 GLOBALS.export_onnx_opset_version = opset_version 648 GLOBALS.operator_export_type = operator_export_type 649 650 do_constant_folding = utils._decide_constant_folding( 651 do_constant_folding, operator_export_type, training 652 ) 653 654 # TODO: Below is doing aten graph to onnx. It should be abstracted as a 655 # function in torch/onnx/utils.py. 656 graph = graph.copy() 657 graph = utils._optimize_graph( 658 graph, 659 operator_export_type, 660 params_dict=params_dict, 661 dynamic_axes=dynamic_axes, 662 input_names=input_names, 663 ) 664 665 if training is None or training == _C_onnx.TrainingMode.EVAL: 666 params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) 667 668 if ( 669 do_constant_folding 670 and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET 671 ): 672 params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) 673 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 674 675 if GLOBALS.onnx_shape_inference: 676 _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) 677 678 params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) 679 680 # For ONNX opset < 9, constants only have three data types: float16, float, double. 681 # In this pass transform constants of other data types to float/double + cast operator. 682 if opset_version < 9: 683 _C._jit_pass_onnx_cast_all_constant_to_floating(graph) 684 685 params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) 686 _C._jit_decay_packed_param_input_types(graph) 687 688 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 689 690 if export_options.verbose: 691 print("ONNX graph: ", graph) 692 693 return graph, params_dict 694 695 696def _onnx_proto_from_onnx_graph( 697 onnx_graph: torch.Graph, 698 export_options: _experimental.ExportOptions, 699 params_dict: dict[str, Any], 700) -> tuple[bytes, Mapping[str, bytes]]: 701 opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET 702 dynamic_axes = export_options.dynamic_axes or {} 703 operator_export_type = export_options.operator_export_type 704 val_keep_init_as_ip = utils._decide_keep_init_as_input( 705 export_options.keep_initializers_as_inputs, 706 operator_export_type, 707 opset_version, 708 ) 709 val_add_node_names = utils._decide_add_node_names(True, operator_export_type) 710 custom_opsets = export_options.custom_opsets or {} 711 712 proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] 713 params_dict, 714 opset_version, 715 dynamic_axes, 716 False, 717 operator_export_type, 718 not export_options.verbose, 719 val_keep_init_as_ip, 720 custom_opsets, 721 val_add_node_names, 722 "", 723 {}, 724 ) 725 726 return proto, export_map 727 728 729def check_export_model_diff( 730 model: torch.nn.Module | torch.jit.ScriptModule, 731 test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], 732 export_options: _experimental.ExportOptions | None = None, 733) -> str: 734 """Verify exported model discrepancy between different groups of inputs. 735 736 A graph is exported for each group of inputs. The exported graphs are then compared 737 to each other, and discrepancies of first pair of nodes are reported. This function 738 first checks the jit graph. If no discrepancies were found, it then checks the onnx 739 graph. 740 741 Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless 742 of the inputs used for exporting. A discrepancy implies the graph exported is 743 not accurate when run on other groups of inputs, which will typically results in 744 runtime errors or mismatching output. 745 746 Args: 747 model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. 748 test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence 749 of input groups to be used to export the model. Each input group is a pair of 750 (args, kwargs). 751 export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions 752 object that controls the export behavior. 753 754 Returns: 755 str: A string containing the diff of the exported models. 756 """ 757 export_options = ( 758 _experimental.ExportOptions() if export_options is None else export_options 759 ) 760 761 jit_diff_report = _check_graph_diff( 762 model, test_input_groups, export_options, _traced_graph_from_model 763 ) 764 if jit_diff_report: 765 return jit_diff_report 766 767 return _check_graph_diff( 768 model, test_input_groups, export_options, _onnx_graph_from_model 769 ) 770 771 772def verify( 773 model: _ModelType, 774 input_args: _InputArgsType, 775 input_kwargs: _InputKwargsType | None = None, 776 do_constant_folding: bool = True, 777 dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] 778 | None = None, 779 input_names: Sequence[str] | None = None, 780 output_names: Sequence[str] | None = None, 781 training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 782 opset_version: int | None = None, 783 keep_initializers_as_inputs: bool = True, 784 verbose: bool = False, 785 fixed_batch_size: bool = False, 786 use_external_data: bool = False, 787 additional_test_inputs: Sequence[_InputArgsType] | None = None, 788 options: VerificationOptions | None = None, 789): 790 """Verify model export to ONNX against original PyTorch model. 791 792 Args: 793 model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`. 794 input_args (tuple): See :func:`torch.onnx.export`. 795 input_kwargs (dict): See :func:`torch.onnx.export`. 796 do_constant_folding (bool, optional): See :func:`torch.onnx.export`. 797 dynamic_axes (dict, optional): See :func:`torch.onnx.export`. 798 input_names (list, optional): See :func:`torch.onnx.export`. 799 output_names (list, optional): See :func:`torch.onnx.export`. 800 training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`. 801 opset_version (int, optional): See :func:`torch.onnx.export`. 802 keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`. 803 verbose (bool, optional): See :func:`torch.onnx.export`. 804 fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases. 805 use_external_data (bool, optional): Explicitly specify whether to export the 806 model with external data. 807 additional_test_inputs (list, optional): List of tuples. Each tuple is a group of 808 input arguments to test. Currently only *args are supported. 809 options (_VerificationOptions, optional): A _VerificationOptions object that 810 controls the verification behavior. 811 812 Raises: 813 AssertionError: if outputs from ONNX model and PyTorch model are not 814 equal up to specified precision. 815 ValueError: if arguments provided are invalid. 816 """ 817 if options is None: 818 options = VerificationOptions() 819 820 if training == torch.onnx.TrainingMode.TRAINING: 821 model.train() 822 elif training == torch.onnx.TrainingMode.EVAL: 823 model.eval() 824 with torch.no_grad(), contextlib.ExitStack() as stack: 825 model_f: str | io.BytesIO = io.BytesIO() 826 if use_external_data: 827 tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) 828 model_f = os.path.join(tmpdir_path, "model.onnx") 829 830 inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) 831 832 # TODO(#77679): remove this and treat mutating model separately. 833 model_copy = _try_clone_model(model) 834 utils._export( 835 model, 836 inputs_for_export, 837 model_f, 838 opset_version=opset_version, 839 do_constant_folding=do_constant_folding, 840 keep_initializers_as_inputs=keep_initializers_as_inputs, 841 dynamic_axes=dynamic_axes, 842 input_names=input_names, 843 output_names=output_names, 844 fixed_batch_size=fixed_batch_size, 845 training=training, 846 verbose=verbose, 847 ) 848 849 _compare_onnx_pytorch_model( 850 pt_model=model_copy, 851 onnx_model_f=model_f, 852 input_args=input_args, 853 input_kwargs=input_kwargs, 854 additional_test_inputs=additional_test_inputs, 855 options=options, 856 ) 857 858 859def verify_aten_graph( 860 graph: torch.Graph, 861 input_args: tuple[Any, ...], 862 export_options: _experimental.ExportOptions, 863 params_dict: dict[str, Any] | None = None, 864 verification_options: VerificationOptions | None = None, 865) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: 866 if verification_options is None: 867 verification_options = VerificationOptions() 868 if params_dict is None: 869 params_dict = {} 870 871 original_jit_graph = graph 872 graph = graph.copy() 873 874 # Execute aten graph and get reference torch jit outputs. 875 graph_inputs = list(graph.inputs()) 876 jit_inputs = tuple([arg for arg in input_args if arg is not None]) 877 weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] 878 assert all(w is not None for w in weights) 879 # TODO: Only copy the argument if mutation is detected in Graph. 880 jit_inputs = copy.deepcopy(jit_inputs) 881 jit_input_and_parameters = jit_inputs + tuple(weights) 882 jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] 883 if not isinstance(jit_outs, (list, tuple)): 884 jit_outs = [jit_outs] 885 886 # Convert aten graph to onnx graph. 887 graph, onnx_params_dict = _onnx_graph_from_aten_graph( 888 graph, export_options, params_dict 889 ) 890 891 proto, export_map = _onnx_proto_from_onnx_graph( 892 graph, export_options, onnx_params_dict 893 ) 894 model_f: str | io.BytesIO = io.BytesIO() 895 export_type = _exporter_states.ExportTypes.PROTOBUF_FILE 896 onnx_proto_utils._export_file(proto, model_f, export_type, export_map) 897 898 # NOTE: Verification is unstable. Try catch to emit information for debugging. 899 try: 900 # NOTE: Input might be dce'ed, so we need to remove those from the input args. 901 new_input_names = {v.debugName() for v in graph.inputs()} 902 new_input_args = [] 903 for v, arg in zip(original_jit_graph.inputs(), input_args): 904 if v.debugName() in new_input_names: 905 new_input_args.append(arg) 906 input_args = tuple(new_input_args) 907 908 onnx_inputs = _prepare_input_for_onnx( 909 input_args, 910 {}, 911 verification_options.remained_onnx_input_idx, 912 verification_options.flatten, 913 ) 914 915 onnx_session = _onnx_backend_session(model_f, verification_options.backend) 916 onnx_outs = _run_onnx(onnx_session, onnx_inputs) 917 del onnx_session # To free device memory 918 919 try: 920 _compare_onnx_pytorch_outputs( 921 onnx_outs=onnx_outs, 922 pt_outs=jit_outs, 923 options=verification_options, 924 ) 925 except AssertionError as e: 926 return e, graph, jit_outs, onnx_outs 927 928 return None, graph, jit_outs, onnx_outs 929 930 except Exception as e: 931 print("Unexpected error during verification.") 932 print("jit graph: ", original_jit_graph) 933 print("onnx graph: ", graph) 934 raise e 935 936 937class GraphInfoPrettyPrinter: 938 graph_info: GraphInfo | None 939 upper_printer: GraphInfoPrettyPrinter | None 940 lower_printer: GraphInfoPrettyPrinter | None 941 942 graph_str_lambdas: Mapping[int, str] 943 connector_str_lambdas: Mapping[int, str] 944 children_str_lambdas: Mapping[int, str] 945 946 def __init__(self, graph_info: GraphInfo | None): 947 self.graph_info = graph_info 948 if ( 949 graph_info is not None 950 and graph_info.upper_graph_info is not None 951 and graph_info.lower_graph_info is not None 952 ): 953 self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) 954 self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) 955 else: 956 self.upper_printer = None 957 self.lower_printer = None 958 959 def _total_rows(self) -> int: 960 if self.graph_info is None: 961 return 1 962 if self.upper_printer and self.lower_printer: 963 return ( 964 self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 965 ) 966 return 2 # Two lines: node count + id. 967 968 def _node_count_segment_str(self) -> str: 969 if self.graph_info is None: 970 return "..." 971 node_count = self.graph_info.essential_node_count() 972 has_mismatch = self.graph_info.has_mismatch() 973 error_node_kind = ( 974 f"({self.graph_info.essential_node_kinds().pop()})" 975 if node_count == 1 and has_mismatch 976 else "" 977 ) 978 979 return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" 980 981 def _graph_id_segment_str(self) -> str: 982 if self.graph_info is None: 983 return "" 984 return f"id: {self.graph_info.id}" 985 986 def _max_segment_columns(self) -> int: 987 return max( 988 map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) 989 ) 990 991 def _graph_segment_str_at_line(self, line: int) -> str: 992 """Get the string representation of the graph segment at the given line.""" 993 if line == 0: 994 result_str = self._node_count_segment_str() 995 result_str += " " * (self._max_segment_columns() - len(result_str)) 996 return result_str 997 if line == 1: 998 result_str = self._graph_id_segment_str() 999 result_str += " " * (self._max_segment_columns() - len(result_str)) 1000 return result_str 1001 if 0 <= line < self._total_rows(): 1002 return " " * self._max_segment_columns() 1003 return "" 1004 1005 def _connector_segment_str_at_line(self, line: int) -> str: 1006 """Get the connector segment string at the given line.""" 1007 if self.upper_printer is None and self.lower_printer is None: 1008 return "" 1009 upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 1010 lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 1011 if line == 0: 1012 return " __" 1013 elif line < upper_total_rows + 1: 1014 return " | " 1015 elif line == upper_total_rows + 1: 1016 return " |__" 1017 elif line < upper_total_rows + lower_total_rows + 1: 1018 return " " 1019 return "" 1020 1021 def _children_str_at_line(self, line: int) -> str: 1022 """Get the string representation of the children at the given line. 1023 1024 Recursively calls `_str_at_line` on children nodes. 1025 """ 1026 if self.upper_printer is None and self.lower_printer is None: 1027 return "" 1028 upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 1029 lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 1030 if 0 <= line < upper_total_rows: 1031 return ( 1032 self.upper_printer._str_at_line(line) if self.upper_printer else "..." 1033 ) 1034 elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: 1035 return ( 1036 self.lower_printer._str_at_line(line - upper_total_rows - 1) 1037 if self.lower_printer 1038 else "..." 1039 ) 1040 return "" 1041 1042 def _str_at_line(self, line: int) -> str: 1043 """Get the string representation of the graph at the given line.""" 1044 return ( 1045 self._graph_segment_str_at_line(line) 1046 + self._connector_segment_str_at_line(line) 1047 + self._children_str_at_line(line) 1048 ) 1049 1050 def pretty_print(self): 1051 if self.graph_info is None: 1052 print(None) 1053 return 1054 # Print tree. 1055 print(" Tree: ".center(80, "=")) 1056 total_rows = self._total_rows() 1057 for line in range(total_rows): 1058 print(self._str_at_line(line).rstrip()) 1059 if self.graph_info.has_mismatch(): 1060 # Summarize leaf subgraphs with mismatch. 1061 print(" Mismatch leaf subgraphs: ".center(80, "=")) 1062 print( 1063 [ 1064 graph_info.id 1065 for graph_info in self.graph_info.all_mismatch_leaf_graph_info() 1066 ] 1067 ) 1068 # Summarize node kinds with mismatch. 1069 mismatch_node_kinds: dict[str, int] = {} 1070 for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): 1071 node_kinds = graph_info.essential_node_kinds() 1072 if len(node_kinds) == 1: 1073 node_kind = node_kinds.pop() 1074 mismatch_node_kinds[node_kind] = ( 1075 mismatch_node_kinds.get(node_kind, 0) + 1 1076 ) 1077 print(" Mismatch node kinds: ".center(80, "=")) 1078 print(mismatch_node_kinds) 1079 else: 1080 print(" No mismatch found. ".center(80, "=")) 1081 1082 1083class OnnxTestCaseRepro: 1084 def __init__(self, repro_dir): 1085 self.repro_dir = repro_dir 1086 self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( 1087 repro_dir 1088 ) 1089 1090 @classmethod 1091 def create_test_case_repro( 1092 cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None 1093 ): 1094 """Create a repro under "{dir}/test_{name}" for an ONNX test case. 1095 1096 The test case contains the model and the inputs/outputs data. The directory 1097 structure is as follows: 1098 1099 dir 1100 \u251c\u2500\u2500 test_<name> 1101 \u2502 \u251c\u2500\u2500 model.onnx 1102 \u2502 \u2514\u2500\u2500 test_data_set_0 1103 \u2502 \u251c\u2500\u2500 input_0.pb 1104 \u2502 \u251c\u2500\u2500 input_1.pb 1105 \u2502 \u251c\u2500\u2500 output_0.pb 1106 \u2502 \u2514\u2500\u2500 output_1.pb 1107 1108 Args: 1109 proto: ONNX model proto. 1110 inputs: Inputs to the model. 1111 outputs: Outputs of the model. 1112 dir: Directory to save the repro. 1113 name: Name of the test case. If not specified, a name based on current time 1114 will be generated. 1115 Returns: 1116 Path to the repro. 1117 """ 1118 if name is None: 1119 name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 1120 return onnx_proto_utils.export_as_test_case( 1121 proto, 1122 _to_numpy(inputs), 1123 _to_numpy(outputs), 1124 name, 1125 dir, 1126 ) 1127 1128 def validate(self, options: VerificationOptions): 1129 """Run the ONNX test case with options.backend, and compare with the expected outputs. 1130 1131 Args: 1132 options: Options for validation. 1133 1134 Raise: 1135 AssertionError: if outputs from options.backend and expected outputs are not 1136 equal up to specified precision. 1137 """ 1138 onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) 1139 run_outputs = onnx_session.run(None, self.inputs) 1140 if hasattr(onnx_session, "get_outputs"): 1141 output_names = [o.name for o in onnx_session.get_outputs()] 1142 elif hasattr(onnx_session, "output_names"): 1143 output_names = onnx_session.output_names 1144 else: 1145 raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") 1146 expected_outs = [self.outputs[name] for name in output_names] 1147 _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) 1148 1149 1150@dataclasses.dataclass 1151class GraphInfo: 1152 """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.""" 1153 1154 graph: torch.Graph 1155 input_args: tuple[Any, ...] 1156 params_dict: dict[str, Any] 1157 export_options: _experimental.ExportOptions = dataclasses.field( 1158 default_factory=_experimental.ExportOptions 1159 ) 1160 mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) 1161 pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) 1162 upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) 1163 lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) 1164 id: str = dataclasses.field(default="") 1165 _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) 1166 1167 _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( 1168 {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} 1169 ) 1170 1171 def clear(self): 1172 """Clear states and results of previous verification.""" 1173 self.mismatch_error = None 1174 self.pt_outs = None 1175 self._onnx_graph = None 1176 self.upper_graph_info = None 1177 self.lower_graph_info = None 1178 1179 def pretty_print_tree(self): 1180 """Pretty print `GraphInfo` tree. 1181 1182 Each node represents a subgraph, showing the number of nodes in the subgraph and 1183 a check mark if the subgraph has output mismatch between torch and ONNX. 1184 1185 The id of the subgraph is shown under the node. The `GraphInfo` object for any 1186 subgraph can be retrieved by calling `graph_info.find_partition(id)`. 1187 1188 Example:: 1189 1190 ==================================== Tree: ===================================== 1191 5 X __2 X __1 \u2713 1192 id: | id: 0 | id: 00 1193 | | 1194 | |__1 X (aten::relu) 1195 | id: 01 1196 | 1197 |__3 X __1 \u2713 1198 id: 1 | id: 10 1199 | 1200 |__2 X __1 X (aten::relu) 1201 id: 11 | id: 110 1202 | 1203 |__1 \u2713 1204 id: 111 1205 =========================== Mismatch leaf subgraphs: =========================== 1206 ['01', '110'] 1207 ============================= Mismatch node kinds: ============================= 1208 {'aten::relu': 2} 1209 1210 """ 1211 GraphInfoPrettyPrinter(self).pretty_print() 1212 1213 def pretty_print_mismatch(self, graph: bool = False): 1214 """Pretty print details of the mismatch between torch and ONNX. 1215 1216 Args: 1217 graph: If True, print the ATen JIT graph and ONNX graph. 1218 """ 1219 print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) 1220 if graph: 1221 print(" ATen JIT graph ".center(80, "=")) 1222 # TODO: A more compact graph printer. 1223 # * Drop stride, grad, device information. 1224 # * Show source location on a separate line. 1225 print(self.graph) 1226 if self._onnx_graph is not None: 1227 print(" ONNX graph ".center(80, "=")) 1228 print(self._onnx_graph) 1229 if self.has_mismatch(): 1230 print(" Mismatch error ".center(80, "=")) 1231 print(self.mismatch_error) 1232 else: 1233 print(" No mismatch ".center(80, "=")) 1234 1235 def has_mismatch(self) -> bool: 1236 """Return True if the subgraph has output mismatch between torch and ONNX.""" 1237 return self.mismatch_error is not None 1238 1239 def essential_node_count(self) -> int: 1240 """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" 1241 return sum( 1242 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS 1243 ) 1244 1245 def essential_node_kinds(self) -> set[str]: 1246 """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" 1247 return { 1248 n.kind() 1249 for n in self.graph.nodes() 1250 if n.kind() not in self._EXCLUDED_NODE_KINDS 1251 } 1252 1253 def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: 1254 """Return a list of all leaf `GraphInfo` objects that have mismatch.""" 1255 if not self.has_mismatch(): 1256 return [] 1257 1258 no_mismatch_children = ( 1259 self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() 1260 ) and ( 1261 self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() 1262 ) 1263 1264 if no_mismatch_children: 1265 return [self] 1266 1267 results = [] 1268 if self.upper_graph_info is not None: 1269 results += self.upper_graph_info.all_mismatch_leaf_graph_info() 1270 if self.lower_graph_info is not None: 1271 results += self.lower_graph_info.all_mismatch_leaf_graph_info() 1272 1273 return results 1274 1275 def find_partition(self, id: str) -> GraphInfo | None: 1276 """Find the `GraphInfo` object with the given id.""" 1277 if id == self.id: 1278 return self 1279 current_length = len(self.id) 1280 if len(id) > current_length: 1281 if id[current_length] == "0" and self.upper_graph_info is not None: 1282 return self.upper_graph_info.find_partition(id) 1283 elif id[current_length] == "1" and self.lower_graph_info is not None: 1284 return self.lower_graph_info.find_partition(id) 1285 return None 1286 1287 def export_repro( 1288 self, repro_dir: str | None = None, name: str | None = None 1289 ) -> str: 1290 """Export the subgraph to ONNX along with the input/output data for repro. 1291 1292 The repro directory will contain the following files:: 1293 1294 dir 1295 \u251c\u2500\u2500 test_<name> 1296 \u2502 \u251c\u2500\u2500 model.onnx 1297 \u2502 \u2514\u2500\u2500 test_data_set_0 1298 \u2502 \u251c\u2500\u2500 input_0.pb 1299 \u2502 \u251c\u2500\u2500 input_1.pb 1300 \u2502 \u251c\u2500\u2500 output_0.pb 1301 \u2502 \u2514\u2500\u2500 output_1.pb 1302 1303 Args: 1304 repro_dir: The directory to export the repro files to. Defaults to current 1305 working directory if None. 1306 name: An optional name for the test case folder: "test_{name}". 1307 1308 Returns: 1309 The path to the exported repro directory. 1310 """ 1311 1312 if repro_dir is None: 1313 repro_dir = os.getcwd() 1314 repro_dir = os.path.join(repro_dir, "onnx_debug") 1315 1316 onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( 1317 self.graph, self.export_options, self.params_dict 1318 ) 1319 1320 proto, _ = _onnx_proto_from_onnx_graph( 1321 onnx_graph, self.export_options, onnx_params_dict 1322 ) 1323 return OnnxTestCaseRepro.create_test_case_repro( 1324 proto, self.input_args, self.pt_outs, repro_dir, name 1325 ) 1326 1327 def _graph_partition_pivot(self) -> int: 1328 """Find the pivot index to partition the graph. 1329 1330 The pivot is the node that splits the graph into two parts. Each part should 1331 have the similar amount of nodes, excluding non essential ops, defined in 1332 `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. 1333 If the graph has an odd number of nodes, the upper part will have one more node. 1334 If the graph does not have any node that can be partitioned, return -1. 1335 1336 Returns: 1337 The index of the pivot node. 1338 """ 1339 included_node_indices = [ 1340 i 1341 for i, n in enumerate(self.graph.nodes()) 1342 if n.kind() not in self._EXCLUDED_NODE_KINDS 1343 ] 1344 half_idx = len(included_node_indices) // 2 - 1 1345 if half_idx >= 0 and len(included_node_indices) > half_idx: 1346 return included_node_indices[half_idx] + 1 1347 return -1 1348 1349 def _partition_upper_graph(self) -> torch.Graph: 1350 pivot = self._graph_partition_pivot() 1351 if pivot == -1: 1352 return torch.Graph() 1353 graph = self.graph.copy() # Copy to not mutate parent graph. 1354 original_outputs = list(graph.outputs()) 1355 1356 def _process_bridge_value_for_upper( 1357 new_outputs: list[torch.Value], bridge_value: torch.Value 1358 ) -> torch.Value: 1359 # Add bridge values as upper graph outputs. 1360 new_outputs.append(bridge_value) 1361 return bridge_value 1362 1363 new_outputs: list[torch.Value] = [] 1364 process_bridge_value_for_upper = functools.partial( 1365 _process_bridge_value_for_upper, new_outputs 1366 ) 1367 _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( 1368 graph, pivot, process_bridge_value_for_upper 1369 ) 1370 1371 for _ in enumerate(original_outputs): 1372 graph.eraseOutput(0) 1373 for output in new_outputs: 1374 graph.registerOutput(output) 1375 1376 for node in reversed(dropped_nodes): 1377 node.destroy() 1378 1379 for i, input in reversed(list(enumerate(list(graph.inputs())))): 1380 if ( 1381 not _has_uses_by_nodes(input, complete_upper_nodes_set) 1382 and input not in new_outputs 1383 ): 1384 try: 1385 graph.eraseInput(i) 1386 except RuntimeError as e: 1387 print(input, graph) 1388 raise e 1389 1390 return graph 1391 1392 def _partition_lower_graph(self) -> torch.Graph: 1393 pivot = self._graph_partition_pivot() 1394 if pivot == -1: 1395 return torch.Graph() 1396 graph = self.graph.copy() # Copy to not mutate parent graph. 1397 original_outputs = list(graph.outputs()) 1398 original_inputs = list(graph.inputs()) 1399 1400 new_outputs = [] 1401 1402 def _process_bridge_value_for_lower( 1403 graph: torch.Graph, bridge_value: torch.Value 1404 ) -> torch.Value: 1405 # Add bridge values as lower graph inputs. 1406 new_input = graph.addInput() 1407 bridge_value.replaceAllUsesWith(new_input) 1408 new_input.copyMetadata(bridge_value) 1409 return new_input 1410 1411 process_bridge_value_for_lower = functools.partial( 1412 _process_bridge_value_for_lower, graph 1413 ) 1414 1415 upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( 1416 graph, pivot, process_bridge_value_for_lower 1417 ) 1418 1419 for output in original_outputs: 1420 if _produced_by(output, lower_nodes): 1421 new_outputs.append(output) 1422 for _ in enumerate(original_outputs): 1423 graph.eraseOutput(0) 1424 for output in new_outputs: 1425 graph.registerOutput(output) 1426 1427 for input in original_inputs: 1428 if _has_uses_by_nodes(input, complete_lower_nodes_set): 1429 new_input = graph.addInput() 1430 input.replaceAllUsesWith(new_input) 1431 new_input.copyMetadata(input) 1432 1433 for node in reversed(upper_nodes): 1434 if node not in complete_lower_nodes_set: 1435 try: 1436 node.destroy() 1437 except RuntimeError as e: 1438 print(node, graph) 1439 raise e 1440 1441 for _ in original_inputs: 1442 graph.eraseInput(0) 1443 1444 return graph 1445 1446 def _partition_node( 1447 self, 1448 node: torch.Node, 1449 complete_upper_nodes_set: set[torch.Node], 1450 complete_lower_nodes_set: set[torch.Node], 1451 original_graph_outputs: set[torch.Value], 1452 covered_bridge_values: set[torch.Value], 1453 process_bridge_value: Callable[[torch.Value], torch.Value], 1454 ): 1455 if node in complete_lower_nodes_set: 1456 return 1457 1458 if ( 1459 _node_has_uses_by(node, complete_lower_nodes_set) 1460 and node.kind() in self._EXCLUDED_NODE_KINDS 1461 ): 1462 complete_lower_nodes_set.update(_all_nodes([node])) 1463 for input in node.inputs(): 1464 if input in covered_bridge_values: 1465 continue 1466 self._partition_node( 1467 input.node(), 1468 complete_upper_nodes_set, 1469 complete_lower_nodes_set, 1470 original_graph_outputs, 1471 covered_bridge_values, 1472 process_bridge_value, 1473 ) 1474 else: 1475 for output in node.outputs(): 1476 if output in covered_bridge_values: 1477 continue 1478 if ( 1479 _has_uses_by_nodes(output, complete_lower_nodes_set) 1480 or output in original_graph_outputs 1481 ): 1482 covered_bridge_values.add(process_bridge_value(output)) 1483 1484 def _partition_nodes( 1485 self, 1486 graph: torch.Graph, 1487 pivot: int, 1488 process_bridge_value: Callable[[torch.Value], torch.Value], 1489 ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: 1490 nodes = list(graph.nodes()) 1491 upper_nodes = nodes[:pivot] 1492 lower_nodes = nodes[pivot:] 1493 # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter 1494 # recursively contains nodes in subblock of `upper_nodes`. 1495 # The same applies for `lower_nodes` and `complete_lower_nodes_set`. 1496 # With addition that `complete_lower_nodes_set` will include nodes that 1497 # are determined to be copied from `upper_nodes` to `lower_nodes`. 1498 complete_upper_nodes_set = _all_nodes(upper_nodes) 1499 complete_lower_nodes_set = _all_nodes(lower_nodes) 1500 original_graph_outputs = set(graph.outputs()) 1501 # Bridge values are values produced from upper graph, and consumed 1502 # by lower graph. These values need to be become upper graph outputs 1503 # and lower graph inputs, to bridge the interaction. 1504 # Start with all graph inputs marked as covered. If any graph input is 1505 # needed by lower graph, just keep it in lower graph inputs later. 1506 covered_bridge_values = set(graph.inputs()) 1507 for node in upper_nodes: 1508 self._partition_node( 1509 node, 1510 complete_upper_nodes_set, 1511 complete_lower_nodes_set, 1512 original_graph_outputs, 1513 covered_bridge_values, 1514 process_bridge_value, 1515 ) 1516 return ( 1517 upper_nodes, 1518 lower_nodes, 1519 complete_upper_nodes_set, 1520 complete_lower_nodes_set, 1521 ) 1522 1523 def _bridge_kwargs(self): 1524 pt_outs = self.pt_outs 1525 graph_outputs = list(self.graph.outputs()) 1526 assert pt_outs is not None 1527 assert len(graph_outputs) == len( 1528 pt_outs 1529 ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" 1530 return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} 1531 1532 def _args_and_params_for_partition_graph( 1533 self, 1534 graph: torch.Graph, 1535 bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], 1536 full_kwargs: Mapping[str, torch.Tensor], 1537 full_params: Mapping[str, torch.Tensor], 1538 ): 1539 input_names = [input.debugName() for input in graph.inputs()] 1540 args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) 1541 args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) 1542 params = {k: full_params[k] for k in input_names if k in full_params} 1543 assert len(args) + len(params) == len( 1544 input_names 1545 ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" 1546 return args, params 1547 1548 def verify_export( 1549 self, options: VerificationOptions 1550 ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: 1551 """ 1552 Verify the export from TorchScript IR graph to ONNX. 1553 1554 Export the TorchScript IR graph to ONNX, with the inputs, parameters and export 1555 options recorded in this object. Then verify the exported ONNX graph against 1556 the original TorchScript IR graph under the provided verification options. 1557 1558 Args: 1559 options: The verification options. 1560 1561 Returns: 1562 error: The AssertionError raised during the verification. Returns None if no 1563 error is raised. 1564 onnx_graph: The exported ONNX graph in TorchScript IR format. 1565 onnx_outs: The outputs from running exported ONNX model under the onnx 1566 backend in `options`. 1567 pt_outs: The outputs from running the TorchScript IR graph. 1568 """ 1569 return verify_aten_graph( 1570 self.graph, 1571 input_args=self.input_args, 1572 params_dict=self.params_dict, 1573 export_options=self.export_options, 1574 verification_options=options, 1575 ) 1576 1577 def find_mismatch( 1578 self, 1579 options: VerificationOptions | None = None, 1580 ): 1581 """ 1582 Find all mismatches between the TorchScript IR graph and the exported onnx model. 1583 1584 Binary searches the model graph to find the minimal subgraph that exhibits the 1585 mismatch. A `GraphInfo` object is created for each subgraph, recording the test 1586 inputs and export options, as well as the validation results. 1587 1588 Args: 1589 options: The verification options. 1590 """ 1591 self.clear() 1592 1593 if options is None: 1594 options = VerificationOptions() 1595 1596 if self.export_options.verbose: 1597 print(self.graph) 1598 1599 if len(list(self.graph.outputs())) == 0: 1600 return 1601 1602 assert len(self.input_args) + len(self.params_dict) == len( 1603 list(self.graph.inputs()) 1604 ), ( 1605 f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " 1606 f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." 1607 ) 1608 1609 self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( 1610 options 1611 ) 1612 1613 if self.mismatch_error is None: 1614 # No mismatch found in graph. 1615 return 1616 1617 if self.essential_node_count() <= 1: 1618 # Reached leaf node, no more partitioning. 1619 return 1620 1621 full_kwargs = { 1622 k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) 1623 } 1624 full_params = self.params_dict 1625 1626 upper_graph = self._partition_upper_graph() 1627 upper_args, upper_params = self._args_and_params_for_partition_graph( 1628 upper_graph, {}, full_kwargs, full_params 1629 ) 1630 self.upper_graph_info = GraphInfo( 1631 upper_graph, 1632 upper_args, 1633 upper_params, 1634 self.export_options, 1635 id=self.id + "0", 1636 ) 1637 1638 self.upper_graph_info.find_mismatch(options) 1639 1640 bridge_kwargs = self.upper_graph_info._bridge_kwargs() 1641 lower_graph = self._partition_lower_graph() 1642 lower_args, lower_params = self._args_and_params_for_partition_graph( 1643 lower_graph, bridge_kwargs, full_kwargs, full_params 1644 ) 1645 self.lower_graph_info = GraphInfo( 1646 lower_graph, 1647 lower_args, 1648 lower_params, 1649 self.export_options, 1650 id=self.id + "1", 1651 ) 1652 1653 self.lower_graph_info.find_mismatch(options) 1654 1655 1656def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: 1657 all_nodes = set(nodes) 1658 for n in nodes: 1659 for b in n.blocks(): 1660 all_nodes.update(_all_nodes(list(b.nodes()))) 1661 return all_nodes 1662 1663 1664def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: 1665 return any(use.user in nodes for use in value.uses()) 1666 1667 1668def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: 1669 for output in node.outputs(): 1670 if _has_uses_by_nodes(output, nodes): 1671 return True 1672 return False 1673 1674 1675def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: 1676 return value.node() in nodes 1677 1678 1679def find_mismatch( 1680 model: torch.nn.Module | torch.jit.ScriptModule, 1681 input_args: tuple[Any, ...], 1682 do_constant_folding: bool = True, 1683 training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 1684 opset_version: int | None = None, 1685 keep_initializers_as_inputs: bool = True, 1686 verbose: bool = False, 1687 options: VerificationOptions | None = None, 1688) -> GraphInfo: 1689 r"""Find all mismatches between the original model and the exported model. 1690 1691 Experimental. The API is subject to change. 1692 1693 This tool helps debug the mismatch between the original PyTorch model and exported 1694 ONNX model. It binary searches the model graph to find the minimal subgraph that 1695 exhibits the mismatch. 1696 1697 Args: 1698 model: The model to be exported. 1699 input_args: The input arguments to the model. 1700 do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. 1701 training: Same as `training` in :func:`torch.onnx.export`. 1702 opset_version: Same as `opset_version` in :func:`torch.onnx.export`. 1703 keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. 1704 verbose: Same as `verbose` in :func:`torch.onnx.export`. 1705 options: The options for the mismatch verification. 1706 1707 Returns: 1708 A GraphInfo object that contains the mismatch information. 1709 1710 Example:: 1711 1712 >>> import torch 1713 >>> import torch.onnx.verification 1714 >>> torch.manual_seed(0) 1715 >>> opset_version = 15 1716 >>> # Define a custom symbolic function for aten::relu. 1717 >>> # The custom symbolic function is incorrect, which will result in mismatches. 1718 >>> def incorrect_relu_symbolic_function(g, self): 1719 ... return self 1720 >>> torch.onnx.register_custom_op_symbolic( 1721 ... "aten::relu", 1722 ... incorrect_relu_symbolic_function, 1723 ... opset_version=opset_version, 1724 ... ) 1725 >>> class Model(torch.nn.Module): 1726 ... def __init__(self) -> None: 1727 ... super().__init__() 1728 ... self.layers = torch.nn.Sequential( 1729 ... torch.nn.Linear(3, 4), 1730 ... torch.nn.ReLU(), 1731 ... torch.nn.Linear(4, 5), 1732 ... torch.nn.ReLU(), 1733 ... torch.nn.Linear(5, 6), 1734 ... ) 1735 ... def forward(self, x): 1736 ... return self.layers(x) 1737 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) 1738 >>> graph_info = torch.onnx.verification.find_mismatch( 1739 ... Model(), 1740 ... (torch.randn(2, 3),), 1741 ... opset_version=opset_version, 1742 ... ) 1743 ===================== Mismatch info for graph partition : ====================== 1744 ================================ Mismatch error ================================ 1745 Tensor-likes are not close! 1746 Mismatched elements: 12 / 12 (100.0%) 1747 Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) 1748 Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) 1749 ==================================== Tree: ===================================== 1750 5 X __2 X __1 \u2713 1751 id: | id: 0 | id: 00 1752 | | 1753 | |__1 X (aten::relu) 1754 | id: 01 1755 | 1756 |__3 X __1 \u2713 1757 id: 1 | id: 10 1758 | 1759 |__2 X __1 X (aten::relu) 1760 id: 11 | id: 110 1761 | 1762 |__1 \u2713 1763 id: 111 1764 =========================== Mismatch leaf subgraphs: =========================== 1765 ['01', '110'] 1766 ============================= Mismatch node kinds: ============================= 1767 {'aten::relu': 2} 1768 1769 """ 1770 if options is None: 1771 options = VerificationOptions() 1772 if opset_version is None: 1773 opset_version = _constants.ONNX_DEFAULT_OPSET 1774 """From aten graph, do binary search on graph partition to find operator export discrepancy.""" 1775 # TODO: Copied from utils.py `export` until `_optimize_graph`. 1776 if training == torch.onnx.TrainingMode.TRAINING: 1777 model.train() 1778 elif training == torch.onnx.TrainingMode.EVAL: 1779 model.eval() 1780 with torch.no_grad(): 1781 inputs_for_export = _prepare_input_for_export(input_args, {}) 1782 args = utils._decide_input_format(model, inputs_for_export) 1783 1784 model = utils._pre_trace_quant_model(model, args) 1785 graph, params, torch_out, module = utils._create_jit_graph(model, args) 1786 params_dict = utils._get_named_param_dict(graph, params) 1787 1788 utils._apply_friendly_debug_names(graph, params_dict) 1789 1790 graph_info = GraphInfo( 1791 graph, 1792 input_args, 1793 params_dict, 1794 _experimental.ExportOptions( 1795 do_constant_folding=do_constant_folding, 1796 training=training, 1797 opset_version=opset_version, 1798 keep_initializers_as_inputs=keep_initializers_as_inputs, 1799 verbose=verbose, 1800 ), 1801 ) 1802 graph_info.find_mismatch(options) 1803 graph_info.pretty_print_mismatch() 1804 graph_info.pretty_print_tree() 1805 1806 return graph_info 1807