# mypy: allow-untyped-defs
"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.

ONNX Runtime is required, and is used as the ONNX backend for export verification.
"""

from __future__ import annotations

import contextlib
import copy
import dataclasses
import datetime
import difflib
import enum
import functools
import io
import itertools
import os
import tempfile
import warnings
from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union

import numpy as np

import torch
import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, _exporter_states, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number


_ORT_PROVIDERS = ("CPUExecutionProvider",)

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_InputKwargsType = Mapping[str, Any]
_OutputsType = Union[Sequence[_NumericType], Sequence]


class OnnxBackend(enum.Enum):
    """Enum class for ONNX backend used for export verification."""

    REFERENCE = "ONNXReferenceEvaluator"
    ONNX_RUNTIME_CPU = "CPUExecutionProvider"
    ONNX_RUNTIME_CUDA = "CUDAExecutionProvider"


@dataclasses.dataclass
class VerificationOptions:
    """Options for ONNX export verification.

    Attributes:
        flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
            Tensors for ONNX. Set this to False if nested structures are to be preserved
            for ONNX, which is usually the case when exporting ScriptModules. Defaults to True.
        ignore_none: Whether to ignore None types in the torch output, which is usually the
            case with tracing. Set this to False if the torch output should keep None types,
            which is usually the case when exporting ScriptModules. Defaults to True.
        check_shape: Whether to check that the shapes of PyTorch and ONNX Runtime outputs
            are exactly the same. Set this to False to allow output shape broadcasting.
            Defaults to True.
        check_dtype: Whether to check that the dtypes of PyTorch and ONNX Runtime outputs
            are consistent. Defaults to True.
        backend: ONNX backend for verification. Defaults to OnnxBackend.ONNX_RUNTIME_CPU.
        rtol: relative tolerance for comparing ONNX and PyTorch outputs.
        atol: absolute tolerance for comparing ONNX and PyTorch outputs.
        remained_onnx_input_idx: If provided, only the specified inputs will be passed
            to the ONNX model. Supply a list when there are unused inputs in the model.
            Since unused inputs are removed from the exported ONNX model, supplying
            all inputs will cause an error on unexpected inputs. This parameter tells
            the verifier which inputs to pass into the ONNX model.
        acceptable_error_percentage: acceptable percentage of element mismatches in comparison.
            It should be a float between 0.0 and 1.0.
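
    Example (illustrative; ``model`` and ``x`` are placeholders)::

        options = VerificationOptions(rtol=1e-2, atol=1e-4, check_shape=False)
        # Pass the options to `verify` to loosen tolerances and allow
        # broadcastable output shapes:
        # verify(model, (x,), options=options)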
    """

    flatten: bool = True
    ignore_none: bool = True
    check_shape: bool = True
    check_dtype: bool = True
    backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU
    rtol: float = 1e-3
    atol: float = 1e-7
    remained_onnx_input_idx: Sequence[int] | None = None
    acceptable_error_percentage: float | None = None


def _flatten_tuples(elem):
    flattened = []
    for t in elem:
        if isinstance(t, tuple):
            flattened.extend(_flatten_tuples(t))
        else:
            flattened.append(t)
    return flattened


# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> list | np.ndarray:
    if isinstance(elem, torch.Tensor):
        if elem.requires_grad:
            return elem.detach().cpu().numpy()
        else:
            return elem.cpu().numpy()
    elif isinstance(elem, (list, tuple)):
        return [_to_numpy(inp) for inp in elem]
    elif isinstance(elem, (bool, int, float)):
        return np.array(elem)
    elif isinstance(elem, dict):
        flattened = []
        for k in elem:
            flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
        return flattened
    return elem


def _inline_flatten_list(inputs, res_list) -> list:
    for i in inputs:
        if isinstance(i, (list, tuple)):
            _inline_flatten_list(i, res_list)
        else:
            res_list.append(i)
    return res_list


def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
    value_unpacked = []
    for value in values:
        value_unpacked.extend(
            utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted)
        )
    return [_to_numpy(v) for v in value_unpacked]


def _run_onnx(onnx_session, inputs) -> _OutputsType:
    kw_inputs = {}
    if inputs and isinstance(inputs[-1], dict):
        kw_inputs = inputs[-1]
        inputs = inputs[:-1]
    inputs = _unpack_to_numpy(_flatten_tuples(inputs))
    ort_inputs = {}
    for input_name, input in kw_inputs.items():
        ort_inputs[input_name] = _to_numpy(input)
    inputs = _to_numpy(inputs)
    if hasattr(onnx_session, "get_inputs"):
        # onnxruntime.InferenceSession
        input_names = [i.name for i in onnx_session.get_inputs()]
    elif hasattr(onnx_session, "input_names"):
        # onnx.reference.ReferenceEvaluator
        input_names = onnx_session.input_names
    else:
        raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.")

    for i, input in enumerate(inputs):
        if i == len(input_names) or input_names[i] in ort_inputs:
            raise ValueError(
                f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. "
                f"input names: {input_names}."
            )
        ort_inputs[input_names[i]] = input
    onnx_outs = onnx_session.run(None, ort_inputs)
    return onnx_outs


def _ort_session(
    model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS
):
    try:
        import onnxruntime  # type: ignore[import]
    except ImportError as e:
        raise ImportError("onnxruntime is required for export verification.") from e

    if ort_providers is None:
        ort_providers = _ORT_PROVIDERS

    session_options = onnxruntime.SessionOptions()
    # Suppress ort warnings.
    # 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.
    session_options.log_severity_level = 3
    ort_session = onnxruntime.InferenceSession(
        model if isinstance(model, str) else model.getvalue(),
        session_options,
        providers=ort_providers,
    )
    return ort_session


def _onnx_reference_evaluator_session(model: str | io.BytesIO):
    try:
        import onnx
        from onnx import reference as onnx_reference  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc

    proto = (
        onnx.load(model)  # type: ignore[attr-defined]
        if isinstance(model, str)
        else onnx.load_model_from_string(model.getvalue())  # type: ignore[attr-defined]
    )
    onnx_session = onnx_reference.ReferenceEvaluator(proto)
    return onnx_session


def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend):
    if backend == OnnxBackend.REFERENCE:
        onnx_session = _onnx_reference_evaluator_session(model)
    elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}:
        onnx_session = _ort_session(model, (backend.value,))
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    return onnx_session


def _compare_onnx_pytorch_outputs_in_np(
    onnx_outs: _OutputsType,
    pt_outs: _OutputsType,
    options: VerificationOptions,
):
    assert (
        len(onnx_outs) == len(pt_outs)
    ), f"Number of outputs differ: ONNX runtime ({len(onnx_outs)}) vs PyTorch ({len(pt_outs)})"
    acceptable_error_percentage = options.acceptable_error_percentage
    if acceptable_error_percentage and (
        acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
    ):
        raise ValueError(
            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
        )

    for ort_out, pt_out in zip(onnx_outs, pt_outs):
        try:
            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
            if not options.check_shape:
                # Allow different but broadcastable output shapes.
                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
            torch.testing.assert_close(
                ort_out,
                pt_out,
                rtol=options.rtol,
                atol=options.atol,
                check_dtype=options.check_dtype,
                equal_nan=True,
            )
        except AssertionError as e:
            if acceptable_error_percentage:
                error_percentage = 1 - np.sum(
                    np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
                ) / np.prod(ort_out.shape)
                if error_percentage <= acceptable_error_percentage:
                    warnings.warn(
                        f"Suppressed AssertionError:\n{e}.\n"
                        f"Error percentage {error_percentage} "
                        f"within acceptable range {acceptable_error_percentage}."
                    )
                    continue
            if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
                warnings.warn("ONNX output is quantized")
            if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
                warnings.warn("PyTorch output is quantized")
            raise


def _compare_onnx_pytorch_outputs(
    onnx_outs: _OutputsType,
    pt_outs: Any,
    options: VerificationOptions,
):
    """
    Compare ONNX and PyTorch outputs.

    Args:
        onnx_outs: outputs from ONNX backend.
        pt_outs: outputs from PyTorch.
        options: options for verification.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
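
    Example (illustrative sketch; assumes a single matching tensor output)::

        onnx_outs = [np.array([1.0, 2.0], dtype=np.float32)]
        pt_outs = torch.tensor([1.0, 2.0])
        # Passes silently when outputs match within rtol/atol; raises
        # AssertionError otherwise.
        _compare_onnx_pytorch_outputs(onnx_outs, pt_outs, VerificationOptions())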
    """
    if options.ignore_none:
        # torch.jit._flatten filters None type
        pt_outs, _ = torch.jit._flatten(pt_outs)
    else:
        pt_outs = _inline_flatten_list([pt_outs], [])
    pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
    onnx_outs = _inline_flatten_list(onnx_outs, [])
    _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)


def _prepare_input_for_pytorch(args, kwargs):
    """Prepare input for PyTorch model execution.

    Any future changes/formatting to the input before dispatching to the PyTorch
    model should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
    """
    if isinstance(args, (torch.Tensor, dict)):
        args = (args,)
    # In-place operators will update input tensor data as well.
    # Thus inputs are replicated before every forward call.
    args = copy.deepcopy(args)
    if kwargs:
        kwargs = copy.deepcopy(kwargs)
    else:
        kwargs = {}
    return args, kwargs


def _prepare_input_for_export(args, kwargs):
    """Prepare input for ONNX model export.

    Any future changes/formatting to the input before dispatching to the
    :func:`torch.onnx.export` API should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        onnx_inputs: positional arguments for ONNX model export, as `args` in
            :func:`torch.onnx.export`.
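
    Example (illustrative sketch of the trailing-dict convention; ``x`` and ``y``
    are placeholder tensors; returned values are deep copies)::

        _prepare_input_for_export((x,), {"y": y})       # -> (x, {"y": y})
        # A trailing dict arg gets an empty dict appended so it is not
        # mistaken for kwargs by torch.onnx.export:
        _prepare_input_for_export((x, {"y": y}), {})    # -> (x, {"y": y}, {})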
    """
    args, kwargs = _prepare_input_for_pytorch(args, kwargs)
    if not kwargs and len(args) > 0 and isinstance(args[-1], dict):
        onnx_inputs = args + ({},)
    elif kwargs:
        onnx_inputs = args + (kwargs,)
    else:
        onnx_inputs = args
    return onnx_inputs


def _prepare_input_for_onnx(
    args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool
):
    """Prepare input for ONNX model execution in ONNX backend.

    Any future changes/formatting to the input before dispatching to the ONNX backend
    run should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
        remained_onnx_input_idx: indices of inputs to be used for ONNX model execution.
        flatten: whether to flatten the input before dispatching to the ONNX model execution.

    Returns:
        onnx_inputs: positional arguments for ONNX model execution in ONNX backend.
    """
    onnx_inputs = _prepare_input_for_export(args, kwargs)
    if flatten:
        onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
    elif onnx_inputs and onnx_inputs[-1] == {}:
        # Handle empty kwargs (normally removed by flatten).
        onnx_inputs = onnx_inputs[:-1]
    if remained_onnx_input_idx is not None:
        return [onnx_inputs[i] for i in remained_onnx_input_idx]
    else:
        return onnx_inputs


def _try_clone_model(model):
    """Used for preserving original model in case forward mutates model states."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


def _compare_onnx_pytorch_model(
    pt_model: _ModelType,
    onnx_model_f: str | io.BytesIO,
    input_args: _InputArgsType,
    input_kwargs: _InputKwargsType | None,
    additional_test_inputs: Sequence[_InputArgsType] | None,
    options: VerificationOptions,
):
    """Compare outputs from ONNX model runs with outputs from PyTorch model runs.

    Args:
        pt_model: PyTorch model.
        onnx_model_f: ONNX model file path or file-like object.
        input_args: positional arguments for PyTorch model forward method.
        input_kwargs: keyword arguments for PyTorch model forward method.
        additional_test_inputs: additional positional arguments for PyTorch model
            forward method.
        options: options for verification.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """
    onnx_session = _onnx_backend_session(onnx_model_f, options.backend)

    def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
        pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
        # TODO: remove this and treat mutating model separately. See #77679
        pt_model_copy = _try_clone_model(pt_model)
        pt_outs = pt_model_copy(*pt_args, **pt_kwargs)

        onnx_inputs = _prepare_input_for_onnx(
            input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten
        )

        onnx_outs = _run_onnx(onnx_session, onnx_inputs)

        _compare_onnx_pytorch_outputs(
            onnx_outs=onnx_outs,
            pt_outs=pt_outs,
            options=options,
        )

    compare_onnx_pytorch_model_with_input(input_args, input_kwargs)

    if additional_test_inputs:
        for test_input_args in additional_test_inputs:
            compare_onnx_pytorch_model_with_input(test_input_args, {})


class _GraphDiff:
    """A class to represent the difference between two graphs."""

    def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
        """Construct a _GraphDiff object.

        Args:
            graph_a (_C.Graph): First graph to compare.
            graph_b (_C.Graph): Second graph to compare.
        """
        self.graph_a = graph_a
        self.graph_b = graph_b

    def __str__(self):
        """See function :func:`diff_report`."""
        return self.diff_report()

    def _indent(self, lines: str) -> str:
        return "\n".join(["\t" + line for line in lines.splitlines()])

    def diff_report(self) -> str:
        """Return a string representation of the graph difference.

        The report shows the first pair of nodes that diverges. It also shows the source
        location of the pair of nodes.

        Returns:
            graph_diff_report (str): A string representation of the graph difference.
        """
        graph_a = self.graph_a
        graph_b = self.graph_b

        graph_a_str = str(graph_a)
        graph_b_str = str(graph_b)

        if graph_a_str == graph_b_str:
            return ""

        graph_diff = difflib.ndiff(
            graph_a_str.splitlines(True), graph_b_str.splitlines(True)
        )
        graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))]

        for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()):
            if str(node_a) != str(node_b):
                graph_diff_report.append("First diverging operator:")
                node_diff = difflib.ndiff(
                    str(node_a).splitlines(True), str(node_b).splitlines(True)
                )
                source_printout = ["node diff:", self._indent("".join(node_diff))]

                stack_a = node_a.sourceRange() if node_a else None
                if stack_a:
                    source_printout.extend(
                        ["Former source location:", self._indent(str(stack_a))]
                    )
                stack_b = node_b.sourceRange() if node_b else None
                if stack_b:
                    source_printout.extend(
                        ["Latter source location:", self._indent(str(stack_b))]
                    )

                graph_diff_report.extend(source_printout)

                break

        return "\n".join(graph_diff_report)


def _check_graph_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions,
    model_to_graph_func: Callable[
        [
            torch.nn.Module,
            tuple[Any, ...],
            Mapping[str, Any],
            _experimental.ExportOptions,
        ],
        _C.Graph,
    ],
) -> str:
    """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`.

    Args:
        model: See :func:`check_export_model_diff`.
        test_input_groups: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.
        model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph.

    Returns:
        graph_diff_report (str): A string representation of the graph difference.
    """
    if len(test_input_groups) < 2:
        raise ValueError("Need at least two groups of test inputs to compare.")

    ref_jit_graph = None
    for args, kwargs in test_input_groups:
        jit_graph = model_to_graph_func(model, args, kwargs, export_options)
        if ref_jit_graph is None:
            ref_jit_graph = jit_graph
            continue

        graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report()
        if graph_diff_report:
            return graph_diff_report
    return ""


def _traced_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        jit_graph (_C.Graph): A traced JIT graph.
    """
    training = export_options.training
    verbose = export_options.verbose

    with utils.exporter_context(model, training, verbose):
        export_inputs = _prepare_input_for_export(args, kwargs)
        model = utils._pre_trace_quant_model(model, export_inputs)
        jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs)
        return jit_graph


def _onnx_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET

    utils._setup_trace_module_map(model, export_modules_as_functions)

    if not operator_export_type:
        operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph


def _onnx_graph_from_aten_graph(
    graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any] | None = None,
) -> tuple[torch.Graph, dict[str, Any]]:
    if params_dict is None:
        params_dict = {}
    operator_export_type = export_options.operator_export_type
    dynamic_axes = export_options.dynamic_axes or {}
    input_names = export_options.input_names
    training = export_options.training
    do_constant_folding = export_options.do_constant_folding
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    do_constant_folding = utils._decide_constant_folding(
        do_constant_folding, operator_export_type, training
    )

    # TODO: Below is doing aten graph to onnx. It should be abstracted as a
    # function in torch/onnx/utils.py.
    graph = graph.copy()
    graph = utils._optimize_graph(
        graph,
        operator_export_type,
        params_dict=params_dict,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
    )

    if training is None or training == _C_onnx.TrainingMode.EVAL:
        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

    if (
        do_constant_folding
        and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if export_options.verbose:
        print("ONNX graph: ", graph)

    return graph, params_dict


def _onnx_proto_from_onnx_graph(
    onnx_graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any],
) -> tuple[bytes, Mapping[str, bytes]]:
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
    dynamic_axes = export_options.dynamic_axes or {}
    operator_export_type = export_options.operator_export_type
    val_keep_init_as_ip = utils._decide_keep_init_as_input(
        export_options.keep_initializers_as_inputs,
        operator_export_type,
        opset_version,
    )
    val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
    custom_opsets = export_options.custom_opsets or {}

    proto, export_map, _, _ = onnx_graph._export_onnx(  # type: ignore[attr-defined]
        params_dict,
        opset_version,
        dynamic_axes,
        False,
        operator_export_type,
        not export_options.verbose,
        val_keep_init_as_ip,
        custom_opsets,
        val_add_node_names,
        "",
        {},
    )

    return proto, export_map


def check_export_model_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions | None = None,
) -> str:
    """Verify exported model discrepancy between different groups of inputs.

    A graph is exported for each group of inputs. The exported graphs are then compared
    to each other, and the discrepancy at the first pair of diverging nodes is reported.
    This function first checks the JIT graph. If no discrepancy is found, it then checks
    the ONNX graph.

    Unless otherwise specified, the JIT/ONNX graph is expected to be the same, regardless
    of the inputs used for exporting. A discrepancy implies the exported graph is not
    accurate when run on other groups of inputs, which typically results in runtime
    errors or mismatching output.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
        test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
            of input groups to be used to export the model. Each input group is a pair of
            (args, kwargs).
        export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
            object that controls the export behavior.

    Returns:
        str: A string containing the diff of the exported models.
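
    Example (illustrative; ``MyModel`` is a placeholder)::

        model = MyModel()
        # Export with two different batch sizes; a non-empty report means the
        # exported graph depends on the concrete inputs it was exported with.
        report = check_export_model_diff(
            model, [((torch.randn(2, 3),), {}), ((torch.randn(4, 3),), {})]
        )
        assert not report, report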
    """
    export_options = (
        _experimental.ExportOptions() if export_options is None else export_options
    )

    jit_diff_report = _check_graph_diff(
        model, test_input_groups, export_options, _traced_graph_from_model
    )
    if jit_diff_report:
        return jit_diff_report

    return _check_graph_diff(
        model, test_input_groups, export_options, _onnx_graph_from_model
    )


def verify(
    model: _ModelType,
    input_args: _InputArgsType,
    input_kwargs: _InputKwargsType | None = None,
    do_constant_folding: bool = True,
    dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]]
    | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Sequence[_InputArgsType] | None = None,
    options: VerificationOptions | None = None,
):
    """Verify model export to ONNX against original PyTorch model.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
        input_args (tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
        use_external_data (bool, optional): Explicitly specify whether to export the
            model with external data.
        additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
            input arguments to test. Currently only *args are supported.
        options (VerificationOptions, optional): A VerificationOptions object that
            controls the verification behavior.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
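
    Example (illustrative; ``MyModel`` is a placeholder)::

        model = MyModel().eval()
        x = torch.randn(1, 3, 224, 224)
        # Exports in-memory, runs the ONNX model with ONNX Runtime, and raises
        # AssertionError if outputs mismatch beyond the given tolerances.
        verify(model, (x,), options=VerificationOptions(rtol=1e-3, atol=1e-5))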
    """
    if options is None:
        options = VerificationOptions()

    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()
    with torch.no_grad(), contextlib.ExitStack() as stack:
        model_f: str | io.BytesIO = io.BytesIO()
        if use_external_data:
            tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
            model_f = os.path.join(tmpdir_path, "model.onnx")

        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)

        # TODO(#77679): remove this and treat mutating model separately.
        model_copy = _try_clone_model(model)
        utils._export(
            model,
            inputs_for_export,
            model_f,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            fixed_batch_size=fixed_batch_size,
            training=training,
            verbose=verbose,
        )

        _compare_onnx_pytorch_model(
            pt_model=model_copy,
            onnx_model_f=model_f,
            input_args=input_args,
            input_kwargs=input_kwargs,
            additional_test_inputs=additional_test_inputs,
            options=options,
        )


def verify_aten_graph(
    graph: torch.Graph,
    input_args: tuple[Any, ...],
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any] | None = None,
    verification_options: VerificationOptions | None = None,
) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
    if verification_options is None:
        verification_options = VerificationOptions()
    if params_dict is None:
        params_dict = {}

    original_jit_graph = graph
    graph = graph.copy()

    # Execute aten graph and get reference torch jit outputs.
    graph_inputs = list(graph.inputs())
    jit_inputs = tuple([arg for arg in input_args if arg is not None])
    weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]]
    assert all(w is not None for w in weights)
    # TODO: Only copy the argument if mutation is detected in Graph.
    jit_inputs = copy.deepcopy(jit_inputs)
    jit_input_and_parameters = jit_inputs + tuple(weights)
    jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters)  # type: ignore[attr-defined]
    if not isinstance(jit_outs, (list, tuple)):
        jit_outs = [jit_outs]

    # Convert aten graph to onnx graph.
    graph, onnx_params_dict = _onnx_graph_from_aten_graph(
        graph, export_options, params_dict
    )

    proto, export_map = _onnx_proto_from_onnx_graph(
        graph, export_options, onnx_params_dict
    )
    model_f: str | io.BytesIO = io.BytesIO()
    export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
    onnx_proto_utils._export_file(proto, model_f, export_type, export_map)

    # NOTE: Verification is unstable. Wrap in try-except to emit information for debugging.
    try:
        # NOTE: Input might be dce'ed, so we need to remove those from the input args.
        new_input_names = {v.debugName() for v in graph.inputs()}
        new_input_args = []
        for v, arg in zip(original_jit_graph.inputs(), input_args):
            if v.debugName() in new_input_names:
                new_input_args.append(arg)
        input_args = tuple(new_input_args)

        onnx_inputs = _prepare_input_for_onnx(
            input_args,
            {},
            verification_options.remained_onnx_input_idx,
            verification_options.flatten,
        )

        onnx_session = _onnx_backend_session(model_f, verification_options.backend)
        onnx_outs = _run_onnx(onnx_session, onnx_inputs)
        del onnx_session  # To free device memory

        try:
            _compare_onnx_pytorch_outputs(
                onnx_outs=onnx_outs,
                pt_outs=jit_outs,
                options=verification_options,
            )
        except AssertionError as e:
            return e, graph, jit_outs, onnx_outs

        return None, graph, jit_outs, onnx_outs

    except Exception as e:
        print("Unexpected error during verification.")
        print("jit graph: ", original_jit_graph)
        print("onnx graph: ", graph)
        raise e


class GraphInfoPrettyPrinter:
    graph_info: GraphInfo | None
    upper_printer: GraphInfoPrettyPrinter | None
    lower_printer: GraphInfoPrettyPrinter | None

    graph_str_lambdas: Mapping[int, str]
    connector_str_lambdas: Mapping[int, str]
    children_str_lambdas: Mapping[int, str]

    def __init__(self, graph_info: GraphInfo | None):
        self.graph_info = graph_info
        if (
            graph_info is not None
            and graph_info.upper_graph_info is not None
            and graph_info.lower_graph_info is not None
        ):
            self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info)
            self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info)
        else:
            self.upper_printer = None
            self.lower_printer = None

    def _total_rows(self) -> int:
        if self.graph_info is None:
            return 1
        if self.upper_printer and self.lower_printer:
            return (
                self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1
            )
        return 2  # Two lines: node count + id.

    def _node_count_segment_str(self) -> str:
        if self.graph_info is None:
            return "..."
        node_count = self.graph_info.essential_node_count()
        has_mismatch = self.graph_info.has_mismatch()
        error_node_kind = (
            f"({self.graph_info.essential_node_kinds().pop()})"
            if node_count == 1 and has_mismatch
            else ""
        )

        return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}"

    def _graph_id_segment_str(self) -> str:
        if self.graph_info is None:
            return ""
        return f"id: {self.graph_info.id}"

    def _max_segment_columns(self) -> int:
        return max(
            map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
        )

    def _graph_segment_str_at_line(self, line: int) -> str:
        """Get the string representation of the graph segment at the given line."""
        if line == 0:
            result_str = self._node_count_segment_str()
            result_str += " " * (self._max_segment_columns() - len(result_str))
            return result_str
        if line == 1:
            result_str = self._graph_id_segment_str()
            result_str += " " * (self._max_segment_columns() - len(result_str))
            return result_str
        if 0 <= line < self._total_rows():
            return " " * self._max_segment_columns()
        return ""

    def _connector_segment_str_at_line(self, line: int) -> str:
        """Get the connector segment string at the given line."""
        if self.upper_printer is None and self.lower_printer is None:
            return ""
        upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
        lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
        if line == 0:
            return "  __"
        elif line < upper_total_rows + 1:
            return " |  "
        elif line == upper_total_rows + 1:
            return " |__"
        elif line < upper_total_rows + lower_total_rows + 1:
            return "    "
        return ""

    def _children_str_at_line(self, line: int) -> str:
        """Get the string representation of the children at the given line.

        Recursively calls `_str_at_line` on children nodes.
        """
        if self.upper_printer is None and self.lower_printer is None:
            return ""
        upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
        lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
        if 0 <= line < upper_total_rows:
            return (
                self.upper_printer._str_at_line(line) if self.upper_printer else "..."
            )
        elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1:
            return (
                self.lower_printer._str_at_line(line - upper_total_rows - 1)
                if self.lower_printer
                else "..."
            )
        return ""

    def _str_at_line(self, line: int) -> str:
        """Get the string representation of the graph at the given line."""
        return (
            self._graph_segment_str_at_line(line)
            + self._connector_segment_str_at_line(line)
            + self._children_str_at_line(line)
        )

    def pretty_print(self):
        if self.graph_info is None:
            print(None)
            return
        # Print tree.
        print(" Tree: ".center(80, "="))
        total_rows = self._total_rows()
        for line in range(total_rows):
            print(self._str_at_line(line).rstrip())
        if self.graph_info.has_mismatch():
            # Summarize leaf subgraphs with mismatch.
            print(" Mismatch leaf subgraphs: ".center(80, "="))
            print(
                [
                    graph_info.id
                    for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
                ]
            )
            # Summarize node kinds with mismatch.
            mismatch_node_kinds: dict[str, int] = {}
            for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
                node_kinds = graph_info.essential_node_kinds()
                if len(node_kinds) == 1:
                    node_kind = node_kinds.pop()
                    mismatch_node_kinds[node_kind] = (
                        mismatch_node_kinds.get(node_kind, 0) + 1
                    )
            print(" Mismatch node kinds: ".center(80, "="))
            print(mismatch_node_kinds)
        else:
            print(" No mismatch found. ".center(80, "="))


class OnnxTestCaseRepro:
    def __init__(self, repro_dir):
        self.repro_dir = repro_dir
        self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case(
            repro_dir
        )

    @classmethod
    def create_test_case_repro(
        cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None
    ):
        """Create a repro under "{dir}/test_{name}" for an ONNX test case.

        The test case contains the model and the inputs/outputs data. The directory
        structure is as follows:

        dir
        ├── test_<name>
        │   ├── model.onnx
        │   └── test_data_set_0
        │       ├── input_0.pb
        │       ├── input_1.pb
        │       ├── output_0.pb
        │       └── output_1.pb

        Args:
            proto: ONNX model proto.
            inputs: Inputs to the model.
            outputs: Outputs of the model.
            dir: Directory to save the repro.
            name: Name of the test case. If not specified, a name based on current time
                will be generated.

        Returns:
            Path to the repro.
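
        Example (illustrative sketch; ``proto``, ``inputs``, and ``outputs``
        come from a prior export and reference run)::

            repro_dir = OnnxTestCaseRepro.create_test_case_repro(
                proto, inputs, outputs, dir=".", name="my_case"
            )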
        """
        if name is None:
            name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
        return onnx_proto_utils.export_as_test_case(
            proto,
            _to_numpy(inputs),
            _to_numpy(outputs),
            name,
            dir,
        )

    def validate(self, options: VerificationOptions):
        """Run the ONNX test case with options.backend, and compare with the expected outputs.

        Args:
            options: Options for validation.

        Raises:
            AssertionError: if outputs from options.backend and expected outputs are not
                equal up to specified precision.
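
        Example (illustrative; ``repro_dir`` points at a directory created by
        :meth:`create_test_case_repro`)::

            repro = OnnxTestCaseRepro(repro_dir)
            # Re-run the saved model on the saved inputs and compare against
            # the saved outputs.
            repro.validate(VerificationOptions())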
        """
        onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend)
        run_outputs = onnx_session.run(None, self.inputs)
        if hasattr(onnx_session, "get_outputs"):
            output_names = [o.name for o in onnx_session.get_outputs()]
        elif hasattr(onnx_session, "output_names"):
            output_names = onnx_session.output_names
        else:
            raise ValueError(f"Unknown onnx session type: {type(onnx_session)}")
        expected_outs = [self.outputs[name] for name in output_names]
        _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options)


@dataclasses.dataclass
class GraphInfo:
    """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""

    graph: torch.Graph
    input_args: tuple[Any, ...]
    params_dict: dict[str, Any]
    export_options: _experimental.ExportOptions = dataclasses.field(
        default_factory=_experimental.ExportOptions
    )
    mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False)
    pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False)
    upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
    lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
    id: str = dataclasses.field(default="")
    _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None)

    _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset(
        {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"}
    )

    def clear(self):
        """Clear states and results of previous verification."""
        self.mismatch_error = None
        self.pt_outs = None
        self._onnx_graph = None
        self.upper_graph_info = None
        self.lower_graph_info = None

    def pretty_print_tree(self):
        """Pretty print `GraphInfo` tree.

        Each node represents a subgraph, showing the number of nodes in the subgraph and
        a check mark if the subgraph has output mismatch between torch and ONNX.

        The id of the subgraph is shown under the node. The `GraphInfo` object for any
        subgraph can be retrieved by calling `graph_info.find_partition(id)`.

        Example::

            ==================================== Tree: =====================================
            5 X   __2 X    __1 ✓
            id:  |  id: 0 |  id: 00
                 |        |
                 |        |__1 X (aten::relu)
                 |           id: 01
                 |
                 |__3 X    __1 ✓
                    id: 1 |  id: 10
                          |
                          |__2 X     __1 X (aten::relu)
                             id: 11 |  id: 110
                                    |
                                    |__1 ✓
                                       id: 111
            =========================== Mismatch leaf subgraphs: ===========================
            ['01', '110']
            ============================= Mismatch node kinds: =============================
            {'aten::relu': 2}

        """
        GraphInfoPrettyPrinter(self).pretty_print()

    def pretty_print_mismatch(self, graph: bool = False):
        """Pretty print details of the mismatch between torch and ONNX.

        Args:
            graph: If True, print the ATen JIT graph and ONNX graph.
        """
        print(f" Mismatch info for graph partition {self.id}: ".center(80, "="))
        if graph:
            print(" ATen JIT graph ".center(80, "="))
            # TODO: A more compact graph printer.
            #   * Drop stride, grad, device information.
            #   * Show source location on a separate line.
            print(self.graph)
            if self._onnx_graph is not None:
                print(" ONNX graph ".center(80, "="))
                print(self._onnx_graph)
        if self.has_mismatch():
            print(" Mismatch error ".center(80, "="))
            print(self.mismatch_error)
        else:
            print(" No mismatch ".center(80, "="))

    def has_mismatch(self) -> bool:
        """Return True if the subgraph has output mismatch between torch and ONNX."""
        return self.mismatch_error is not None

    def essential_node_count(self) -> int:
        """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
        return sum(
            1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
        )

    def essential_node_kinds(self) -> set[str]:
        """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
        return {
            n.kind()
            for n in self.graph.nodes()
            if n.kind() not in self._EXCLUDED_NODE_KINDS
        }

    def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]:
        """Return a list of all leaf `GraphInfo` objects that have mismatch."""
        if not self.has_mismatch():
            return []

        no_mismatch_children = (
            self.upper_graph_info is None or not self.upper_graph_info.has_mismatch()
        ) and (
            self.lower_graph_info is None or not self.lower_graph_info.has_mismatch()
        )

        if no_mismatch_children:
            return [self]

        results = []
        if self.upper_graph_info is not None:
            results += self.upper_graph_info.all_mismatch_leaf_graph_info()
        if self.lower_graph_info is not None:
            results += self.lower_graph_info.all_mismatch_leaf_graph_info()

        return results

    def find_partition(self, id: str) -> GraphInfo | None:
        """Find the `GraphInfo` object with the given id."""
        if id == self.id:
            return self
        current_length = len(self.id)
        if len(id) > current_length:
            if id[current_length] == "0" and self.upper_graph_info is not None:
                return self.upper_graph_info.find_partition(id)
            elif id[current_length] == "1" and self.lower_graph_info is not None:
                return self.lower_graph_info.find_partition(id)
        return None

    def export_repro(
        self, repro_dir: str | None = None, name: str | None = None
    ) -> str:
        """Export the subgraph to ONNX along with the input/output data for repro.

        The repro directory will contain the following files::

            dir
            ├── test_<name>
            │   ├── model.onnx
            │   └── test_data_set_0
            │       ├── input_0.pb
            │       ├── input_1.pb
            │       ├── output_0.pb
            │       └── output_1.pb

        Args:
            repro_dir: The directory to export the repro files to. Defaults to current
                working directory if None.
            name: An optional name for the test case folder: "test_{name}".

        Returns:
            The path to the exported repro directory.
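
        Example (illustrative; ``graph_info`` is a mismatching leaf obtained
        from :meth:`all_mismatch_leaf_graph_info`)::

            repro_path = graph_info.export_repro(name="relu_mismatch")
            # Reload and re-validate the saved case later:
            OnnxTestCaseRepro(repro_path).validate(VerificationOptions())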
        """

        if repro_dir is None:
            repro_dir = os.getcwd()
        repro_dir = os.path.join(repro_dir, "onnx_debug")

        onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
            self.graph, self.export_options, self.params_dict
        )

        proto, _ = _onnx_proto_from_onnx_graph(
            onnx_graph, self.export_options, onnx_params_dict
        )
        return OnnxTestCaseRepro.create_test_case_repro(
            proto, self.input_args, self.pt_outs, repro_dir, name
        )

    def _graph_partition_pivot(self) -> int:
        """Find the pivot index to partition the graph.

        The pivot is the node that splits the graph into two parts. Each part should
        have a similar number of nodes, excluding the non-essential ops defined in
        `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`.
        If the graph has an odd number of nodes, the upper part will have one more node.
        If the graph does not have any node that can be partitioned, return -1.

        Returns:
            The index of the pivot node.
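
        For example, if the essential nodes sit at graph node indices
        ``[0, 2, 5, 7]``, the middle essential node is at index ``2``, so the
        returned pivot is ``3``: nodes before index 3 form the upper partition
        and the rest form the lower partition.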
1338        """
1339        included_node_indices = [
1340            i
1341            for i, n in enumerate(self.graph.nodes())
1342            if n.kind() not in self._EXCLUDED_NODE_KINDS
1343        ]
1344        half_idx = len(included_node_indices) // 2 - 1
1345        if half_idx >= 0 and len(included_node_indices) > half_idx:
1346            return included_node_indices[half_idx] + 1
1347        return -1
1348
1349    def _partition_upper_graph(self) -> torch.Graph:
1350        pivot = self._graph_partition_pivot()
1351        if pivot == -1:
1352            return torch.Graph()
1353        graph = self.graph.copy()  # Copy to not mutate parent graph.
1354        original_outputs = list(graph.outputs())
1355
1356        def _process_bridge_value_for_upper(
1357            new_outputs: list[torch.Value], bridge_value: torch.Value
1358        ) -> torch.Value:
1359            # Add bridge values as upper graph outputs.
1360            new_outputs.append(bridge_value)
1361            return bridge_value
1362
1363        new_outputs: list[torch.Value] = []
1364        process_bridge_value_for_upper = functools.partial(
1365            _process_bridge_value_for_upper, new_outputs
1366        )
1367        _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes(
1368            graph, pivot, process_bridge_value_for_upper
1369        )
1370
1371        for _ in enumerate(original_outputs):
1372            graph.eraseOutput(0)
1373        for output in new_outputs:
1374            graph.registerOutput(output)
1375
1376        for node in reversed(dropped_nodes):
1377            node.destroy()
1378
1379        for i, input in reversed(list(enumerate(list(graph.inputs())))):
1380            if (
1381                not _has_uses_by_nodes(input, complete_upper_nodes_set)
1382                and input not in new_outputs
1383            ):
1384                try:
1385                    graph.eraseInput(i)
1386                except RuntimeError:
1387                    print(input, graph)
1388                    raise
1389
1390        return graph
1391
1392    def _partition_lower_graph(self) -> torch.Graph:
1393        pivot = self._graph_partition_pivot()
1394        if pivot == -1:
1395            return torch.Graph()
1396        graph = self.graph.copy()  # Copy to not mutate parent graph.
1397        original_outputs = list(graph.outputs())
1398        original_inputs = list(graph.inputs())
1399
1400        new_outputs: list[torch.Value] = []
1401
1402        def _process_bridge_value_for_lower(
1403            graph: torch.Graph, bridge_value: torch.Value
1404        ) -> torch.Value:
1405            # Add bridge values as lower graph inputs.
1406            new_input = graph.addInput()
1407            bridge_value.replaceAllUsesWith(new_input)
1408            new_input.copyMetadata(bridge_value)
1409            return new_input
1410
1411        process_bridge_value_for_lower = functools.partial(
1412            _process_bridge_value_for_lower, graph
1413        )
1414
1415        upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes(
1416            graph, pivot, process_bridge_value_for_lower
1417        )
1418
1419        for output in original_outputs:
1420            if _produced_by(output, lower_nodes):
1421                new_outputs.append(output)
1422        for _ in original_outputs:
1423            graph.eraseOutput(0)
1424        for output in new_outputs:
1425            graph.registerOutput(output)
1426
1427        for input in original_inputs:
1428            if _has_uses_by_nodes(input, complete_lower_nodes_set):
1429                new_input = graph.addInput()
1430                input.replaceAllUsesWith(new_input)
1431                new_input.copyMetadata(input)
1432
1433        for node in reversed(upper_nodes):
1434            if node not in complete_lower_nodes_set:
1435                try:
1436                    node.destroy()
1437                except RuntimeError:
1438                    print(node, graph)
1439                    raise
1440
1441        for _ in original_inputs:
1442            graph.eraseInput(0)
1443
1444        return graph
1445
1446    def _partition_node(
1447        self,
1448        node: torch.Node,
1449        complete_upper_nodes_set: set[torch.Node],
1450        complete_lower_nodes_set: set[torch.Node],
1451        original_graph_outputs: set[torch.Value],
1452        covered_bridge_values: set[torch.Value],
1453        process_bridge_value: Callable[[torch.Value], torch.Value],
1454    ):
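        # Recursively decide how `node` (from the upper partition) interacts
        # with the lower partition: cheap excluded ops (e.g. `prim::Constant`)
        # that are used by lower nodes are duplicated into the lower partition;
        # for any other node, outputs consumed by the lower partition (or that
        # are original graph outputs) are turned into bridge values via
        # `process_bridge_value`.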
1455        if node in complete_lower_nodes_set:
1456            return
1457
1458        if (
1459            _node_has_uses_by(node, complete_lower_nodes_set)
1460            and node.kind() in self._EXCLUDED_NODE_KINDS
1461        ):
1462            complete_lower_nodes_set.update(_all_nodes([node]))
1463            for input in node.inputs():
1464                if input in covered_bridge_values:
1465                    continue
1466                self._partition_node(
1467                    input.node(),
1468                    complete_upper_nodes_set,
1469                    complete_lower_nodes_set,
1470                    original_graph_outputs,
1471                    covered_bridge_values,
1472                    process_bridge_value,
1473                )
1474        else:
1475            for output in node.outputs():
1476                if output in covered_bridge_values:
1477                    continue
1478                if (
1479                    _has_uses_by_nodes(output, complete_lower_nodes_set)
1480                    or output in original_graph_outputs
1481                ):
1482                    covered_bridge_values.add(process_bridge_value(output))
1483
1484    def _partition_nodes(
1485        self,
1486        graph: torch.Graph,
1487        pivot: int,
1488        process_bridge_value: Callable[[torch.Value], torch.Value],
1489    ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]:
1490        nodes = list(graph.nodes())
1491        upper_nodes = nodes[:pivot]
1492        lower_nodes = nodes[pivot:]
1493        # `upper_nodes` and `complete_upper_nodes_set` differ in that the latter
1494        # recursively contains nodes in the subblocks of `upper_nodes`.
1495        # The same applies to `lower_nodes` and `complete_lower_nodes_set`,
1496        # with the addition that `complete_lower_nodes_set` also includes nodes
1497        # that are determined to be copied from `upper_nodes` to `lower_nodes`.
1498        complete_upper_nodes_set = _all_nodes(upper_nodes)
1499        complete_lower_nodes_set = _all_nodes(lower_nodes)
1500        original_graph_outputs = set(graph.outputs())
1501        # Bridge values are values produced by the upper graph and consumed
1502        # by the lower graph. These values need to become upper graph outputs
1503        # and lower graph inputs to bridge the two partitions.
1504        # Start with all graph inputs marked as covered. If any graph input is
1505        # needed by the lower graph, it is simply kept as a lower graph input later.
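        # Hypothetical illustration of a bridge value:
        #     upper:  %a = aten::add(%x, %y)
        #     lower:  %b = aten::relu(%a)
        # %a is produced by the upper partition and consumed by the lower one,
        # so it is registered as an upper graph output and a lower graph input.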
1506        covered_bridge_values = set(graph.inputs())
1507        for node in upper_nodes:
1508            self._partition_node(
1509                node,
1510                complete_upper_nodes_set,
1511                complete_lower_nodes_set,
1512                original_graph_outputs,
1513                covered_bridge_values,
1514                process_bridge_value,
1515            )
1516        return (
1517            upper_nodes,
1518            lower_nodes,
1519            complete_upper_nodes_set,
1520            complete_lower_nodes_set,
1521        )
1522
1523    def _bridge_kwargs(self):
1524        pt_outs = self.pt_outs
1525        graph_outputs = list(self.graph.outputs())
1526        assert pt_outs is not None
1527        assert len(graph_outputs) == len(
1528            pt_outs
1529        ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
1530        return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
1531
1532    def _args_and_params_for_partition_graph(
1533        self,
1534        graph: torch.Graph,
1535        bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]],
1536        full_kwargs: Mapping[str, torch.Tensor],
1537        full_params: Mapping[str, torch.Tensor],
1538    ):
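        # Match the partition graph's inputs by debug name: bridge values
        # (upper-partition outputs feeding the lower partition) precede the
        # original model inputs in `graph.inputs()`, since
        # `_partition_lower_graph` registers bridge inputs first; parameters
        # are looked up from `full_params`.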
1539        input_names = [input.debugName() for input in graph.inputs()]
1540        args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs)
1541        args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs)
1542        params = {k: full_params[k] for k in input_names if k in full_params}
1543        assert len(args) + len(params) == len(
1544            input_names
1545        ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
1546        return args, params
1547
1548    def verify_export(
1549        self, options: VerificationOptions
1550    ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
1551        """
1552        Verify the export from TorchScript IR graph to ONNX.
1553
1554        Export the TorchScript IR graph to ONNX, with the inputs, parameters and export
1555        options recorded in this object. Then verify the exported ONNX graph against
1556        the original TorchScript IR graph under the provided verification options.
1557
1558        Args:
1559            options: The verification options.
1560
1561        Returns:
1562            error: The AssertionError raised during verification. None if no
1563            error is raised.
1564            onnx_graph: The exported ONNX graph in TorchScript IR format.
1565            onnx_outs: The outputs from running the exported ONNX model under
1566            the ONNX backend specified in `options`.
1567            pt_outs: The outputs from running the TorchScript IR graph.
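
        Example (a sketch; assumes ``graph_info`` is a populated ``GraphInfo``)::

            >>> # xdoctest: +SKIP
            >>> options = VerificationOptions(rtol=1e-3, atol=1e-5)
            >>> error, onnx_graph, onnx_outs, pt_outs = graph_info.verify_export(options)
            >>> if error is not None:
            ...     print(error)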
1568        """
1569        return verify_aten_graph(
1570            self.graph,
1571            input_args=self.input_args,
1572            params_dict=self.params_dict,
1573            export_options=self.export_options,
1574            verification_options=options,
1575        )
1576
1577    def find_mismatch(
1578        self,
1579        options: VerificationOptions | None = None,
1580    ):
1581        """
1582        Find all mismatches between the TorchScript IR graph and the exported ONNX model.
1583
1584        Binary searches the model graph to find the minimal subgraph that exhibits the
1585        mismatch. A `GraphInfo` object is created for each subgraph, recording the test
1586        inputs and export options, as well as the validation results.
1587
1588        Args:
1589            options: The verification options.
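
        Example (a sketch; assumes ``graph_info`` is a populated ``GraphInfo``)::

            >>> # xdoctest: +SKIP
            >>> graph_info.find_mismatch(VerificationOptions(atol=1e-4))
            >>> if graph_info.mismatch_error is not None:
            ...     # Inspect graph_info.upper_graph_info / graph_info.lower_graph_info
            ...     # for the bisected subgraphs.
            ...     print(graph_info.mismatch_error)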
1590        """
1591        self.clear()
1592
1593        if options is None:
1594            options = VerificationOptions()
1595
1596        if self.export_options.verbose:
1597            print(self.graph)
1598
1599        if len(list(self.graph.outputs())) == 0:
1600            return
1601
1602        assert len(self.input_args) + len(self.params_dict) == len(
1603            list(self.graph.inputs())
1604        ), (
1605            f"Number of graph inputs ({len(list(self.graph.inputs()))}) does not match "
1606            f"the number of provided tensor arguments ({len(self.input_args)} + {len(self.params_dict)})."
1607        )
1608
1609        self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export(
1610            options
1611        )
1612
1613        if self.mismatch_error is None:
1614            # No mismatch found in graph.
1615            return
1616
1617        if self.essential_node_count() <= 1:
1618            # Reached leaf node, no more partitioning.
1619            return
1620
1621        full_kwargs = {
1622            k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args)
1623        }
1624        full_params = self.params_dict
1625
1626        upper_graph = self._partition_upper_graph()
1627        upper_args, upper_params = self._args_and_params_for_partition_graph(
1628            upper_graph, {}, full_kwargs, full_params
1629        )
1630        self.upper_graph_info = GraphInfo(
1631            upper_graph,
1632            upper_args,
1633            upper_params,
1634            self.export_options,
1635            id=self.id + "0",
1636        )
1637
1638        self.upper_graph_info.find_mismatch(options)
1639
1640        bridge_kwargs = self.upper_graph_info._bridge_kwargs()
1641        lower_graph = self._partition_lower_graph()
1642        lower_args, lower_params = self._args_and_params_for_partition_graph(
1643            lower_graph, bridge_kwargs, full_kwargs, full_params
1644        )
1645        self.lower_graph_info = GraphInfo(
1646            lower_graph,
1647            lower_args,
1648            lower_params,
1649            self.export_options,
1650            id=self.id + "1",
1651        )
1652
1653        self.lower_graph_info.find_mismatch(options)
1654
1655
1656def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]:
1657    all_nodes = set(nodes)
1658    for n in nodes:
1659        for b in n.blocks():
1660            all_nodes.update(_all_nodes(list(b.nodes())))
1661    return all_nodes
1662
1663
1664def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
1665    return any(use.user in nodes for use in value.uses())
1666
1667
1668def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
1669    for output in node.outputs():
1670        if _has_uses_by_nodes(output, nodes):
1671            return True
1672    return False
1673
1674
1675def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
1676    return value.node() in nodes
1677
1678
1679def find_mismatch(
1680    model: torch.nn.Module | torch.jit.ScriptModule,
1681    input_args: tuple[Any, ...],
1682    do_constant_folding: bool = True,
1683    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
1684    opset_version: int | None = None,
1685    keep_initializers_as_inputs: bool = True,
1686    verbose: bool = False,
1687    options: VerificationOptions | None = None,
1688) -> GraphInfo:
1689    r"""Find all mismatches between the original model and the exported model.
1690
1691    Experimental. The API is subject to change.
1692
1693    This tool helps debug the mismatch between the original PyTorch model and exported
1694    ONNX model. It binary searches the model graph to find the minimal subgraph that
1695    exhibits the mismatch.
1696
1697    Args:
1698        model: The model to be exported.
1699        input_args: The input arguments to the model.
1700        do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`.
1701        training: Same as `training` in :func:`torch.onnx.export`.
1702        opset_version: Same as `opset_version` in :func:`torch.onnx.export`.
1703        keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`.
1704        verbose: Same as `verbose` in :func:`torch.onnx.export`.
1705        options: The options for the mismatch verification.
1706
1707    Returns:
1708        A GraphInfo object that contains the mismatch information.
1709
1710    Example::
1711
1712        >>> import torch
1713        >>> import torch.onnx.verification
1714        >>> torch.manual_seed(0)
1715        >>> opset_version = 15
1716        >>> # Define a custom symbolic function for aten::relu.
1717        >>> # The custom symbolic function is incorrect, which will result in mismatches.
1718        >>> def incorrect_relu_symbolic_function(g, self):
1719        ...     return self
1720        >>> torch.onnx.register_custom_op_symbolic(
1721        ...     "aten::relu",
1722        ...     incorrect_relu_symbolic_function,
1723        ...     opset_version=opset_version,
1724        ... )
1725        >>> class Model(torch.nn.Module):
1726        ...     def __init__(self) -> None:
1727        ...         super().__init__()
1728        ...         self.layers = torch.nn.Sequential(
1729        ...             torch.nn.Linear(3, 4),
1730        ...             torch.nn.ReLU(),
1731        ...             torch.nn.Linear(4, 5),
1732        ...             torch.nn.ReLU(),
1733        ...             torch.nn.Linear(5, 6),
1734        ...         )
1735        ...     def forward(self, x):
1736        ...         return self.layers(x)
1737        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
1738        >>> graph_info = torch.onnx.verification.find_mismatch(
1739        ...     Model(),
1740        ...     (torch.randn(2, 3),),
1741        ...     opset_version=opset_version,
1742        ... )
1743        ===================== Mismatch info for graph partition : ======================
1744        ================================ Mismatch error ================================
1745        Tensor-likes are not close!
1746        Mismatched elements: 12 / 12 (100.0%)
1747        Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
1748        Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
1749        ==================================== Tree: =====================================
1750        5 X   __2 X    __1 ✓
1751        id:  |  id: 0 |  id: 00
1752             |        |
1753             |        |__1 X (aten::relu)
1754             |           id: 01
1755             |
1756             |__3 X    __1 ✓
1757                id: 1 |  id: 10
1758                      |
1759                      |__2 X     __1 X (aten::relu)
1760                         id: 11 |  id: 110
1761                                |
1762                                |__1 ✓
1763                                   id: 111
1764        =========================== Mismatch leaf subgraphs: ===========================
1765        ['01', '110']
1766        ============================= Mismatch node kinds: =============================
1767        {'aten::relu': 2}
1768
1769    """
1770    if options is None:
1771        options = VerificationOptions()
1772    if opset_version is None:
1773        opset_version = _constants.ONNX_DEFAULT_OPSET
1774    # From the ATen graph, binary search over graph partitions to find the operator with an export discrepancy.
1775    # TODO: Copied from utils.py `export` until `_optimize_graph`.
1776    if training == torch.onnx.TrainingMode.TRAINING:
1777        model.train()
1778    elif training == torch.onnx.TrainingMode.EVAL:
1779        model.eval()
1780    with torch.no_grad():
1781        inputs_for_export = _prepare_input_for_export(input_args, {})
1782        args = utils._decide_input_format(model, inputs_for_export)
1783
1784        model = utils._pre_trace_quant_model(model, args)
1785        graph, params, torch_out, module = utils._create_jit_graph(model, args)
1786        params_dict = utils._get_named_param_dict(graph, params)
1787
1788        utils._apply_friendly_debug_names(graph, params_dict)
1789
1790        graph_info = GraphInfo(
1791            graph,
1792            input_args,
1793            params_dict,
1794            _experimental.ExportOptions(
1795                do_constant_folding=do_constant_folding,
1796                training=training,
1797                opset_version=opset_version,
1798                keep_initializers_as_inputs=keep_initializers_as_inputs,
1799                verbose=verbose,
1800            ),
1801        )
1802        graph_info.find_mismatch(options)
1803        graph_info.pretty_print_mismatch()
1804        graph_info.pretty_print_tree()
1805
1806        return graph_info
1807