xref: /aosp_15_r20/external/pytorch/torch/onnx/verification.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard WorkerONNX Runtime is required, and is used as the ONNX backend for export verification.
5*da0073e9SAndroid Build Coastguard Worker"""
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport contextlib
10*da0073e9SAndroid Build Coastguard Workerimport copy
11*da0073e9SAndroid Build Coastguard Workerimport dataclasses
12*da0073e9SAndroid Build Coastguard Workerimport datetime
13*da0073e9SAndroid Build Coastguard Workerimport difflib
14*da0073e9SAndroid Build Coastguard Workerimport enum
15*da0073e9SAndroid Build Coastguard Workerimport functools
16*da0073e9SAndroid Build Coastguard Workerimport io
17*da0073e9SAndroid Build Coastguard Workerimport itertools
18*da0073e9SAndroid Build Coastguard Workerimport os
19*da0073e9SAndroid Build Coastguard Workerimport tempfile
20*da0073e9SAndroid Build Coastguard Workerimport warnings
21*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Workerimport numpy as np
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Workerimport torch
26*da0073e9SAndroid Build Coastguard Workerimport torch._C._onnx as _C_onnx
27*da0073e9SAndroid Build Coastguard Workerfrom torch import _C
28*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx import _constants, _experimental, _exporter_states, utils
29*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx._globals import GLOBALS
30*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx._internal import onnx_proto_utils
31*da0073e9SAndroid Build Coastguard Workerfrom torch.types import Number
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
# Default ONNX Runtime execution providers; used by `_ort_session` when no
# providers (or None) are passed.
_ORT_PROVIDERS = ("CPUExecutionProvider",)

# Type aliases used in the signatures of this module.
_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
# Positional model inputs: a single Tensor or a tuple of arbitrary arguments.
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_InputKwargsType = Mapping[str, Any]
# NOTE(review): the bare `Sequence` member makes this effectively "any Sequence";
# the `Sequence[_NumericType]` member only documents the expected element type.
_OutputsType = Union[Sequence[_NumericType], Sequence]
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
class OnnxBackend(enum.Enum):
    """Enum class for ONNX backend used for export verification.

    For the ONNX Runtime members, the value is an ONNX Runtime execution
    provider name; `_onnx_backend_session` passes it as the single entry of
    the ``providers`` list when creating the inference session.
    """

    # Uses ``onnx.reference.ReferenceEvaluator`` instead of ONNX Runtime.
    REFERENCE = "ONNXReferenceEvaluator"
    # ONNX Runtime with the CPU execution provider.
    ONNX_RUNTIME_CPU = "CPUExecutionProvider"
    # ONNX Runtime with the CUDA execution provider.
    ONNX_RUNTIME_CUDA = "CUDAExecutionProvider"
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker
@dataclasses.dataclass
class VerificationOptions:
    """Options for ONNX export verification.

    Attributes:
        flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
            Tensors for ONNX. Set this to False if nested structures are to be preserved
            for ONNX, which is usually the case with exporting ScriptModules. Defaults to True.
        ignore_none: Whether to ignore None type in torch output, which is usually the
            case with tracing. Set this to False if torch output should keep None type,
            which is usually the case with exporting ScriptModules. Defaults to True.
        check_shape: Whether to check that the shapes of PyTorch and ONNX backend
            outputs are exactly the same. Set this to False to allow output shape
            broadcasting. Defaults to True.
        check_dtype: Whether to check that the dtypes of PyTorch and ONNX backend
            outputs are consistent. Defaults to True.
        backend: ONNX backend for verification. Defaults to OnnxBackend.ONNX_RUNTIME_CPU.
        rtol: Relative tolerance in comparison between ONNX and PyTorch outputs.
            Defaults to 1e-3.
        atol: Absolute tolerance in comparison between ONNX and PyTorch outputs.
            Defaults to 1e-7.
        remained_onnx_input_idx: If provided, only the specified inputs will be passed
            to the ONNX model. Supply a list when there are unused inputs in the model.
            Since unused inputs will be removed in the exported ONNX model, supplying
            all inputs will cause an error on unexpected inputs. This parameter tells
            the verifier which inputs to pass into the ONNX model. Defaults to None,
            which passes all inputs.
        acceptable_error_percentage: Acceptable fraction of mismatched elements per
            output, as a float between 0.0 and 1.0 (e.g. 0.01 tolerates 1% of
            elements outside ``rtol``/``atol``). Defaults to None; None or 0.0
            disables this tolerance.
    """

    flatten: bool = True
    ignore_none: bool = True
    check_shape: bool = True
    check_dtype: bool = True
    backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU
    rtol: float = 1e-3
    atol: float = 1e-7
    remained_onnx_input_idx: Sequence[int] | None = None
    acceptable_error_percentage: float | None = None
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker
def _flatten_tuples(elem):
    """Return the items of ``elem`` as a flat list, expanding nested tuples.

    Only ``tuple`` instances are expanded (recursively). Lists, dicts and any
    other containers are kept as single items. Leaf order is preserved.
    """

    def _walk(items):
        for item in items:
            if isinstance(item, tuple):
                yield from _walk(item)
            else:
                yield item

    return list(_walk(elem))
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker
# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> list | np.ndarray:
    """Convert ``elem`` (possibly nested) into NumPy-friendly values.

    Conversion rules:
        * ``torch.Tensor`` -> ``np.ndarray`` on CPU (detached first when it
          requires grad).
        * ``list`` / ``tuple`` -> ``list`` with every item converted recursively.
        * Python ``bool`` / ``int`` / ``float`` -> 0-d ``np.ndarray``.
        * ``dict`` -> flat ``list`` alternating converted keys and values, in
          iteration order.
        * Anything else is returned unchanged.
    """
    if isinstance(elem, torch.Tensor):
        tensor = elem.detach() if elem.requires_grad else elem
        return tensor.cpu().numpy()
    if isinstance(elem, (list, tuple)):
        return list(map(_to_numpy, elem))
    if isinstance(elem, (bool, int, float)):
        return np.array(elem)
    if isinstance(elem, dict):
        return [
            converted
            for key in elem
            for converted in (_to_numpy(key), _to_numpy(elem[key]))
        ]
    return elem
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker
def _inline_flatten_list(inputs, res_list=None) -> list:
    """Depth-first flatten nested lists/tuples of ``inputs`` into ``res_list``.

    Args:
        inputs: iterable whose items may be arbitrarily nested lists/tuples.
        res_list: list that leaves are appended to. It is mutated in place and
            also returned. A new list is created when None.

    Returns:
        ``res_list`` with every non-list/tuple leaf of ``inputs`` appended in
        depth-first order.
    """
    if res_list is None:
        res_list = []
    for item in inputs:
        if isinstance(item, (list, tuple)):
            _inline_flatten_list(item, res_list)
        else:
            res_list.append(item)
    return res_list
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
    """Unpack each value with ``utils.unpack_quantized_tensor``, then convert to NumPy.

    Args:
        values: iterable of values to unpack.
        cast_onnx_accepted: forwarded unchanged to ``utils.unpack_quantized_tensor``.

    Returns:
        A flat list holding one ``_to_numpy`` result per unpacked item, in order.
    """
    unpacked = list(
        itertools.chain.from_iterable(
            utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted)
            for value in values
        )
    )
    return list(map(_to_numpy, unpacked))
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker
def _run_onnx(onnx_session, inputs) -> _OutputsType:
    """Run ``onnx_session`` on ``inputs`` and return its outputs.

    Args:
        onnx_session: an ``onnxruntime.InferenceSession`` (exposes ``get_inputs()``)
            or an ``onnx.reference.ReferenceEvaluator`` (exposes ``input_names``).
        inputs: sequence of positional inputs. If its last element is a dict, that
            dict is taken as a mapping from ONNX input name to value. Nested tuples
            among the positional inputs are flattened.

    Returns:
        Whatever ``onnx_session.run(None, feeds)`` returns.

    Raises:
        ValueError: if the session type is not recognized, if more positional
            inputs are supplied than the model declares, or if a positional input
            targets a name that is also supplied as a keyword input.
    """
    kw_inputs = {}
    if inputs and isinstance(inputs[-1], dict):
        kw_inputs = inputs[-1]
        inputs = inputs[:-1]
    # `_unpack_to_numpy` already returns `_to_numpy`-converted values, so no
    # second conversion pass is needed.
    positional = _unpack_to_numpy(_flatten_tuples(inputs))
    ort_inputs = {name: _to_numpy(value) for name, value in kw_inputs.items()}

    if hasattr(onnx_session, "get_inputs"):
        # onnxruntime.InferenceSession
        input_names = [i.name for i in onnx_session.get_inputs()]
    elif hasattr(onnx_session, "input_names"):
        # onnx.reference.ReferenceEvaluator
        input_names = onnx_session.input_names
    else:
        raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.")

    if len(positional) > len(input_names):
        raise ValueError(
            f"got too many positional inputs. inputs: {positional}. kw_inputs: {kw_inputs}. "
            f"input names: {input_names}."
        )
    # Positional inputs bind to the model's inputs in declaration order.
    for i, (name, value) in enumerate(zip(input_names, positional)):
        if name in ort_inputs:
            raise ValueError(
                f"positional input {i} ({name!r}) is also given as a keyword input. "
                f"inputs: {positional}. kw_inputs: {kw_inputs}. "
                f"input names: {input_names}."
            )
        ort_inputs[name] = value
    return onnx_session.run(None, ort_inputs)
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker
def _ort_session(
    model: str | io.BytesIO, ort_providers: Sequence[str] | None = _ORT_PROVIDERS
):
    """Create an ``onnxruntime.InferenceSession`` for ``model``.

    Args:
        model: path to an ONNX model file, or an in-memory buffer holding the
            serialized model bytes.
        ort_providers: ONNX Runtime execution providers to use. ``None`` falls
            back to the module default ``_ORT_PROVIDERS``.

    Returns:
        The created ``onnxruntime.InferenceSession``.

    Raises:
        ImportError: if ``onnxruntime`` is not installed.
    """
    try:
        import onnxruntime  # type: ignore[import]
    except ImportError as e:
        raise ImportError("onnxruntime is required for export verification.") from e

    if ort_providers is None:
        ort_providers = _ORT_PROVIDERS

    session_options = onnxruntime.SessionOptions()
    # Suppress ORT warnings by only logging errors and above.
    # 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal. Default is 2.
    session_options.log_severity_level = 3
    # A path is passed through as-is; a buffer is passed as its raw bytes.
    model_source = model if isinstance(model, str) else model.getvalue()
    return onnxruntime.InferenceSession(
        model_source,
        session_options,
        providers=ort_providers,
    )
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker
def _onnx_reference_evaluator_session(model: str | io.BytesIO):
    """Build an ``onnx.reference.ReferenceEvaluator`` for ``model``.

    Args:
        model: path to an ONNX model file, or an in-memory buffer holding the
            serialized model.

    Returns:
        A ``ReferenceEvaluator`` constructed from the loaded model proto.

    Raises:
        ImportError: if ``onnx`` or its ``reference`` module cannot be imported.
    """
    try:
        import onnx
        from onnx import reference as onnx_reference  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc

    if isinstance(model, str):
        proto = onnx.load(model)  # type: ignore[attr-defined]
    else:
        proto = onnx.load_model_from_string(model.getvalue())  # type: ignore[attr-defined]
    return onnx_reference.ReferenceEvaluator(proto)
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker
def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend):
    """Create an inference session for ``model`` on the requested ``backend``.

    Args:
        model: path to an ONNX model file, or an in-memory model buffer.
        backend: which ONNX backend to run the model with.

    Returns:
        A reference-evaluator session or an ONNX Runtime session.

    Raises:
        ValueError: if ``backend`` is not one of the supported members.
    """
    if backend == OnnxBackend.REFERENCE:
        return _onnx_reference_evaluator_session(model)
    if backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}:
        # The enum value is used directly as the ORT execution-provider name.
        return _ort_session(model, (backend.value,))
    raise ValueError(f"Unsupported backend: {backend}")
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker
def _mismatched_element_fraction(ort_out, pt_out, rtol, atol):
    """Return the fraction of elements outside tolerance, or None if not comparable.

    Outputs whose shapes differ, or that have no elements, cannot be compared
    element-wise, so no fraction is computed and the caller must treat the
    mismatch as a hard failure.
    """
    if np.shape(ort_out) != np.shape(pt_out) or np.size(ort_out) == 0:
        return None
    # equal_nan=True mirrors the assert_close call, so NaNs in matching
    # positions are not counted as mismatches.
    close = np.isclose(ort_out, pt_out, rtol=rtol, atol=atol, equal_nan=True)
    return 1.0 - np.count_nonzero(close) / close.size


def _compare_onnx_pytorch_outputs_in_np(
    onnx_outs: _OutputsType,
    pt_outs: _OutputsType,
    options: VerificationOptions,
):
    """Compare flattened ONNX and PyTorch outputs pairwise.

    Args:
        onnx_outs: flat sequence of outputs from the ONNX backend.
        pt_outs: flat sequence of NumPy outputs from PyTorch, in the same order.
        options: verification options (tolerances, shape/dtype checks,
            ``acceptable_error_percentage``).

    Raises:
        AssertionError: if the number of outputs differs, or if an output pair
            does not match within tolerance and the mismatch is not excused by
            ``options.acceptable_error_percentage``.
        ValueError: if ``acceptable_error_percentage`` is set outside [0.0, 1.0].
    """
    # Raise explicitly instead of using `assert` so the check survives `python -O`.
    if len(onnx_outs) != len(pt_outs):
        raise AssertionError(
            f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
        )
    acceptable_error_percentage = options.acceptable_error_percentage
    if acceptable_error_percentage and (
        acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
    ):
        raise ValueError(
            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
        )

    for ort_out, pt_out in zip(onnx_outs, pt_outs):
        try:
            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
            if not options.check_shape:
                # Allow different but broadcastable output shapes.
                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
            torch.testing.assert_close(
                ort_out,
                pt_out,
                rtol=options.rtol,
                atol=options.atol,
                check_dtype=options.check_dtype,
                equal_nan=True,
            )
        except AssertionError as e:
            if acceptable_error_percentage:
                # A shape mismatch is never excused: it returns None here instead of
                # a (possibly negative) fraction from broadcasting in np.isclose.
                error_percentage = _mismatched_element_fraction(
                    ort_out, pt_out, options.rtol, options.atol
                )
                if (
                    error_percentage is not None
                    and error_percentage <= acceptable_error_percentage
                ):
                    warnings.warn(
                        f"Suppressed AssertionError:\n{e}.\n"
                        f"Error percentage {error_percentage} "
                        f"within acceptable range {acceptable_error_percentage}."
                    )
                    continue
            # getattr keeps a non-array output from masking the original error.
            if getattr(ort_out, "dtype", None) in (np.uint8, np.int8):
                warnings.warn("ONNX output is quantized")
            if getattr(pt_out, "dtype", None) in (np.uint8, np.int8):
                warnings.warn("PyTorch output is quantized")
            raise
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker
def _compare_onnx_pytorch_outputs(
    onnx_outs: _OutputsType,
    pt_outs: Any,
    options: VerificationOptions,
):
    """
    Compare ONNX and PyTorch outputs.

    PyTorch outputs are flattened and converted to NumPy, and ONNX outputs are
    flattened. The two flat lists are then compared pairwise. When
    ``options.ignore_none`` is set, flattening goes through
    ``torch.jit._flatten``, which filters out None values.

    Args:
        onnx_outs: outputs from ONNX backend.
        pt_outs: outputs from PyTorch.
        options: options for verification.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
    """
    if not options.ignore_none:
        flat_pt_outs = _inline_flatten_list([pt_outs], [])
    else:
        # torch.jit._flatten filters None type
        flat_pt_outs = torch.jit._flatten(pt_outs)[0]
    pt_outs_np = _unpack_to_numpy(flat_pt_outs, cast_onnx_accepted=False)
    flat_onnx_outs = _inline_flatten_list(onnx_outs, [])
    _compare_onnx_pytorch_outputs_in_np(flat_onnx_outs, pt_outs_np, options)
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker
def _prepare_input_for_pytorch(args, kwargs):
    """Prepare input for PyTorch model execution.

    Any future changes/formatting to the input before dispatching to the PyTorch
    model should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method. A lone
            ``torch.Tensor`` or ``dict`` is wrapped into a 1-tuple.
        kwargs: keyword arguments for PyTorch model forward method. A falsy
            value (e.g. None or an empty dict) becomes ``{}``.

    Returns:
        args: deep copy of the (possibly wrapped) positional arguments.
        kwargs: deep copy of the keyword arguments, or a new empty dict.
    """
    positional = (args,) if isinstance(args, (torch.Tensor, dict)) else args
    # In-place operators inside forward would update the caller's tensors, so
    # every call works on fresh replicas of the inputs.
    positional = copy.deepcopy(positional)
    keyword = copy.deepcopy(kwargs) if kwargs else {}
    return positional, keyword
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker
def _prepare_input_for_export(args, kwargs):
    """Prepare input for ONNX model export.

    Any future changes/formatting to the input before dispatching to the
    :func:`torch.onnx.export` api should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        onnx_inputs: positional arguments for ONNX model export, as `args` in
            :func:`torch.onnx.export`. Non-empty kwargs are appended as a
            trailing dict.
    """
    args, kwargs = _prepare_input_for_pytorch(args, kwargs)
    if kwargs:
        return args + (kwargs,)
    if len(args) > 0 and isinstance(args[-1], dict):
        # NOTE(review): presumably an empty dict is appended so torch.onnx.export
        # does not read the trailing positional dict as keyword arguments — confirm.
        return args + ({},)
    return args
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker
def _prepare_input_for_onnx(
    args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool
):
    """Prepare input for ONNX model execution in ONNX backend.

    Any future changes/formatting to the input before dispatching to the ONNX backend
    run should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
        remained_onnx_input_idx: indices of inputs to be used for ONNX model execution.
            None keeps all inputs.
        flatten: whether to flatten the input before dispatching to the ONNX model execution.

    Returns:
        onnx_inputs: positional arguments for ONNX model execution in ONNX backend.
    """
    onnx_inputs = _prepare_input_for_export(args, kwargs)
    if flatten:
        onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
    elif (
        onnx_inputs
        and isinstance(onnx_inputs[-1], dict)
        and not onnx_inputs[-1]
    ):
        # Handle empty kwargs (normally removed by flatten). Test the type
        # explicitly: `x == {}` on an array-like input compares element-wise and
        # its truth value can be ambiguous.
        onnx_inputs = onnx_inputs[:-1]
    if remained_onnx_input_idx is not None:
        return [onnx_inputs[i] for i in remained_onnx_input_idx]
    return onnx_inputs
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker
def _try_clone_model(model):
    """Return a deep copy of ``model``, or ``model`` itself if copying fails.

    Used for preserving the original model in case forward mutates model states.
    Cloning is best-effort: on any exception a warning is emitted and the
    original object is returned unchanged.
    """
    try:
        clone = copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        clone = model
    return clone
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker
def _compare_onnx_pytorch_model(
    pt_model: _ModelType,
    onnx_model_f: str | io.BytesIO,
    input_args: _InputArgsType,
    input_kwargs: _InputKwargsType | None,
    additional_test_inputs: Sequence[_InputArgsType] | None,
    options: VerificationOptions,
):
    """Check that an exported ONNX model reproduces its PyTorch source model.

    A single ONNX backend session is created up front and reused for every input
    group. For each group the PyTorch model and the ONNX session are run on the
    same inputs, and their outputs are handed to the output comparison.

    Args:
        pt_model: The original PyTorch model.
        onnx_model_f: Path to, or in-memory buffer holding, the exported ONNX model.
        input_args: Positional arguments for the PyTorch model's forward method.
        input_kwargs: Keyword arguments for the PyTorch model's forward method.
        additional_test_inputs: Extra groups of positional arguments to check; each
            extra group is run with empty keyword arguments.
        options: Verification options. ``backend``, ``remained_onnx_input_idx`` and
            ``flatten`` are read here; the whole object is also forwarded to the
            output comparison.

    Raises:
        AssertionError: If outputs from the ONNX model and the PyTorch model are
            not equal up to the specified precision.
    """
    onnx_session = _onnx_backend_session(onnx_model_f, options.backend)

    def _check_group(group_args, group_kwargs):
        # Run one input group through both models and compare the results.
        torch_args, torch_kwargs = _prepare_input_for_pytorch(group_args, group_kwargs)
        # Run on a clone where possible (falling back to the model itself if
        # cloning fails) so forward-pass state mutations do not leak out.
        # TODO: remove this and treat mutating model separately. See #77679
        model_copy = _try_clone_model(pt_model)
        expected_outs = model_copy(*torch_args, **torch_kwargs)

        onnx_feed = _prepare_input_for_onnx(
            group_args, group_kwargs, options.remained_onnx_input_idx, options.flatten
        )
        actual_outs = _run_onnx(onnx_session, onnx_feed)

        _compare_onnx_pytorch_outputs(
            onnx_outs=actual_outs,
            pt_outs=expected_outs,
            options=options,
        )

    _check_group(input_args, input_kwargs)
    for extra_args in additional_test_inputs or ():
        _check_group(extra_args, {})
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker
class _GraphDiff:
    """Holds two JIT graphs and renders a human-readable report of how they differ."""

    def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
        """Remember the two graphs to compare.

        Args:
            graph_a (_C.Graph): Reference graph; reported as the "former" side.
            graph_b (_C.Graph): Graph compared against ``graph_a``; the "latter" side.
        """
        self.graph_a = graph_a
        self.graph_b = graph_b

    def __str__(self):
        """Return the same text as :meth:`diff_report`."""
        return self.diff_report()

    def _indent(self, lines: str) -> str:
        # Prefix every line with one tab so nested sections stand out in the report.
        return "\n".join(f"\t{line}" for line in lines.splitlines())

    def diff_report(self) -> str:
        """Describe the difference between the two graphs.

        The report starts with an ``ndiff`` of the printed graphs. It is followed
        by the first pair of nodes, walked in order, whose printed forms differ,
        and by each of those nodes' source locations when available.

        Returns:
            graph_diff_report (str): The report, or ``""`` when both graphs print
            identically.
        """
        text_a = str(self.graph_a)
        text_b = str(self.graph_b)
        if text_a == text_b:
            return ""

        full_diff = difflib.ndiff(text_a.splitlines(True), text_b.splitlines(True))
        report = ["Graph diff:", self._indent("".join(full_diff))]

        # zip_longest pads the shorter node sequence with None, so a surplus node
        # on either side also counts as a divergence.
        node_pairs = itertools.zip_longest(self.graph_a.nodes(), self.graph_b.nodes())
        mismatch = next(((a, b) for a, b in node_pairs if str(a) != str(b)), None)
        if mismatch is not None:
            report.extend(self._describe_node_mismatch(*mismatch))
        return "\n".join(report)

    def _describe_node_mismatch(self, node_a, node_b) -> list[str]:
        # Render the first diverging node pair plus each node's source location;
        # a missing (falsy) node or an empty source range is skipped.
        node_diff = difflib.ndiff(
            str(node_a).splitlines(True), str(node_b).splitlines(True)
        )
        section = [
            "First diverging operator:",
            "node diff:",
            self._indent("".join(node_diff)),
        ]
        for heading, node in (
            ("Former source location:", node_a),
            ("Latter source location:", node_b),
        ):
            location = node.sourceRange() if node else None
            if location:
                section.extend([heading, self._indent(str(location))])
        return section
498*da0073e9SAndroid Build Coastguard Worker        return "\n".join(graph_diff_report)
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker
def _check_graph_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions,
    model_to_graph_func: Callable[
        [
            torch.nn.Module,
            tuple[Any, ...],
            Mapping[str, Any],
            _experimental.ExportOptions,
        ],
        _C.Graph,
    ],
) -> str:
    """Report whether `model_to_graph_func` yields the same graph for every input group.

    The first graph produced is the reference. Each later graph is diffed against
    it, and the first non-empty diff report is returned.

    Args:
        model: See :func:`check_export_model_diff`.
        test_input_groups: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.
        model_to_graph_func: Converts a PyTorch model plus one ``(args, kwargs)``
            group into a JIT IR graph.

    Returns:
        graph_diff_report (str): Report for the first group whose graph differs
        from the reference, or ``""`` when all graphs print identically.

    Raises:
        ValueError: If fewer than two input groups are given.
    """
    if len(test_input_groups) < 2:
        raise ValueError("Need at least two groups of test inputs to compare.")

    # Build graphs lazily so conversion stops as soon as a mismatch is found.
    graphs = (
        model_to_graph_func(model, group_args, group_kwargs, export_options)
        for group_args, group_kwargs in test_input_groups
    )
    reference = None
    for candidate in graphs:
        if reference is None:
            reference = candidate
            continue
        report = _GraphDiff(reference, candidate).diff_report()
        if report:
            return report
    return ""
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker
def _traced_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """Trace ``model`` into a JIT graph, reproducing the tracing step of ONNX export.

    Only ``export_options.training`` and ``export_options.verbose`` are consulted.
    They configure the exporter context in which the trace runs.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        jit_graph (_C.Graph): The traced JIT graph.
    """
    with utils.exporter_context(
        model, export_options.training, export_options.verbose
    ):
        trace_inputs = _prepare_input_for_export(args, kwargs)
        # The trace runs on whatever model `_pre_trace_quant_model` hands back,
        # which may differ from the caller's `model`.
        traced_model = utils._pre_trace_quant_model(model, trace_inputs)
        (graph, _, _, _) = utils._create_jit_graph(traced_model, trace_inputs)
        return graph
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker
def _onnx_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """Export ``model`` to an ONNX JIT graph, reproducing the graph-building step of ONNX export.

    Side effects: ``utils._setup_trace_module_map`` is called on ``model``, and
    ``GLOBALS.export_onnx_opset_version`` / ``GLOBALS.operator_export_type`` are
    overwritten before conversion.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): The ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opts = export_options
    opset_version = (
        _constants.ONNX_DEFAULT_OPSET
        if opts.opset_version is None
        else opts.opset_version
    )

    utils._setup_trace_module_map(model, opts.export_modules_as_functions)

    # An unset (falsy) operator export type falls back to plain ONNX.
    export_type = opts.operator_export_type or _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = export_type

    training, verbose = opts.training, opts.verbose
    input_names, output_names = opts.input_names, opts.output_names

    with utils.exporter_context(model, training, verbose):
        constant_folding = utils._decide_constant_folding(
            opts.do_constant_folding, export_type, training
        )

        axes = {} if opts.dynamic_axes is None else opts.dynamic_axes
        utils._validate_dynamic_axes(axes, model, input_names, output_names)

        prepared_inputs = utils._decide_input_format(
            model, _prepare_input_for_export(args, kwargs)
        )
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            prepared_inputs,
            verbose,
            input_names,
            output_names,
            export_type,
            constant_folding,
            training=training,
            dynamic_axes=axes,
        )
        return onnx_graph
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker
def _onnx_graph_from_aten_graph(
    graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any] | None = None,
) -> tuple[torch.Graph, dict[str, Any]]:
    """Run the ONNX conversion passes on an ATen-level JIT graph.

    ``graph`` is copied before any pass runs, and every pass below operates on
    that copy. As a side effect, ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type`` are overwritten with the values used here.

    Args:
        graph: The JIT graph to convert.
        export_options: Export settings. ``operator_export_type``, ``dynamic_axes``,
            ``input_names``, ``training``, ``do_constant_folding``,
            ``opset_version`` and ``verbose`` are read.
        params_dict: Parameters referenced by ``graph``, keyed by name
            (presumably graph input/initializer names; verify against callers).
            Defaults to an empty dict.

    Returns:
        A ``(onnx_graph, params_dict)`` pair: the converted copy of ``graph`` and
        the parameter dict as returned by the final filtering pass.
    """
    if params_dict is None:
        params_dict = {}
    # Fall back to plain ONNX when no operator export type is set. This matches
    # the default `_onnx_graph_from_model` applies before setting GLOBALS.
    operator_export_type = (
        export_options.operator_export_type or _C_onnx.OperatorExportTypes.ONNX
    )
    dynamic_axes = export_options.dynamic_axes or {}
    input_names = export_options.input_names
    training = export_options.training
    do_constant_folding = export_options.do_constant_folding
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    do_constant_folding = utils._decide_constant_folding(
        do_constant_folding, operator_export_type, training
    )

    # TODO: Below is doing aten graph to onnx. It should be abstracted as a
    # function in torch/onnx/utils.py.
    graph = graph.copy()
    graph = utils._optimize_graph(
        graph,
        operator_export_type,
        params_dict=params_dict,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
    )

    if training is None or training == _C_onnx.TrainingMode.EVAL:
        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

    if (
        do_constant_folding
        and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if export_options.verbose:
        print("ONNX graph: ", graph)

    return graph, params_dict
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker
def _onnx_proto_from_onnx_graph(
    onnx_graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any],
) -> tuple[bytes, Mapping[str, bytes]]:
    """Serialize an ONNX JIT graph and its parameters with ``_export_onnx``.

    Args:
        onnx_graph: The ONNX JIT graph to serialize.
        export_options: Export settings. ``opset_version``, ``dynamic_axes``,
            ``operator_export_type``, ``keep_initializers_as_inputs``,
            ``custom_opsets`` and ``verbose`` are read.
        params_dict: Parameters to pass along with the graph, keyed by name.

    Returns:
        ``(proto, export_map)``: the first two values returned by
        ``onnx_graph._export_onnx``.
    """
    opts = export_options
    opset_version = opts.opset_version or _constants.ONNX_DEFAULT_OPSET
    export_type = opts.operator_export_type
    keep_initializers_as_inputs = utils._decide_keep_init_as_input(
        opts.keep_initializers_as_inputs,
        export_type,
        opset_version,
    )
    add_node_names = utils._decide_add_node_names(True, export_type)

    # NOTE(review): the literal False / "" / {} arguments presumably mirror the
    # defaults torch.onnx.utils passes to _export_onnx (weight-export deferral,
    # model file location, node-attribute map); confirm against that call site.
    export_args = (
        params_dict,
        opset_version,
        opts.dynamic_axes or {},
        False,
        export_type,
        not opts.verbose,
        keep_initializers_as_inputs,
        opts.custom_opsets or {},
        add_node_names,
        "",
        {},
    )
    proto, export_map, _, _ = onnx_graph._export_onnx(*export_args)  # type: ignore[attr-defined]
    return proto, export_map
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker
def check_export_model_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions | None = None,
) -> str:
    """Report discrepancies between models exported from different groups of inputs.

    A graph is exported for each input group, and the graphs are compared to one
    another. The first pair of diverging nodes is reported. Traced JIT graphs are
    checked first; ONNX graphs are compared only when the JIT graphs agree.

    Unless otherwise specified, the JIT/ONNX graph is expected to be the same
    regardless of the inputs used for exporting. A discrepancy means the exported
    graph is not accurate for other groups of inputs, which typically results in
    runtime errors or mismatching outputs.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
        test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): Input
            groups used to export the model. Each group is an ``(args, kwargs)`` pair.
        export_options (_experimental.ExportOptions, optional): Controls the export
            behavior. A default ``ExportOptions()`` is used when omitted.

    Returns:
        str: The diff of the exported graphs, or ``""`` when no discrepancy is found.
    """
    if export_options is None:
        export_options = _experimental.ExportOptions()

    for to_graph in (_traced_graph_from_model, _onnx_graph_from_model):
        report = _check_graph_diff(model, test_input_groups, export_options, to_graph)
        if report:
            return report
    return ""
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Workerdef verify(
773*da0073e9SAndroid Build Coastguard Worker    model: _ModelType,
774*da0073e9SAndroid Build Coastguard Worker    input_args: _InputArgsType,
775*da0073e9SAndroid Build Coastguard Worker    input_kwargs: _InputKwargsType | None = None,
776*da0073e9SAndroid Build Coastguard Worker    do_constant_folding: bool = True,
777*da0073e9SAndroid Build Coastguard Worker    dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]]
778*da0073e9SAndroid Build Coastguard Worker    | None = None,
779*da0073e9SAndroid Build Coastguard Worker    input_names: Sequence[str] | None = None,
780*da0073e9SAndroid Build Coastguard Worker    output_names: Sequence[str] | None = None,
781*da0073e9SAndroid Build Coastguard Worker    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
782*da0073e9SAndroid Build Coastguard Worker    opset_version: int | None = None,
783*da0073e9SAndroid Build Coastguard Worker    keep_initializers_as_inputs: bool = True,
784*da0073e9SAndroid Build Coastguard Worker    verbose: bool = False,
785*da0073e9SAndroid Build Coastguard Worker    fixed_batch_size: bool = False,
786*da0073e9SAndroid Build Coastguard Worker    use_external_data: bool = False,
787*da0073e9SAndroid Build Coastguard Worker    additional_test_inputs: Sequence[_InputArgsType] | None = None,
788*da0073e9SAndroid Build Coastguard Worker    options: VerificationOptions | None = None,
789*da0073e9SAndroid Build Coastguard Worker):
790*da0073e9SAndroid Build Coastguard Worker    """Verify model export to ONNX against original PyTorch model.
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker    Args:
793*da0073e9SAndroid Build Coastguard Worker        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
794*da0073e9SAndroid Build Coastguard Worker        input_args (tuple): See :func:`torch.onnx.export`.
795*da0073e9SAndroid Build Coastguard Worker        input_kwargs (dict): See :func:`torch.onnx.export`.
796*da0073e9SAndroid Build Coastguard Worker        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
797*da0073e9SAndroid Build Coastguard Worker        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
798*da0073e9SAndroid Build Coastguard Worker        input_names (list, optional): See :func:`torch.onnx.export`.
799*da0073e9SAndroid Build Coastguard Worker        output_names (list, optional): See :func:`torch.onnx.export`.
800*da0073e9SAndroid Build Coastguard Worker        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
801*da0073e9SAndroid Build Coastguard Worker        opset_version (int, optional): See :func:`torch.onnx.export`.
802*da0073e9SAndroid Build Coastguard Worker        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
803*da0073e9SAndroid Build Coastguard Worker        verbose (bool, optional): See :func:`torch.onnx.export`.
804*da0073e9SAndroid Build Coastguard Worker        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
805*da0073e9SAndroid Build Coastguard Worker        use_external_data (bool, optional): Explicitly specify whether to export the
806*da0073e9SAndroid Build Coastguard Worker            model with external data.
807*da0073e9SAndroid Build Coastguard Worker        additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
808*da0073e9SAndroid Build Coastguard Worker            input arguments to test. Currently only *args are supported.
809*da0073e9SAndroid Build Coastguard Worker        options (_VerificationOptions, optional): A _VerificationOptions object that
810*da0073e9SAndroid Build Coastguard Worker            controls the verification behavior.
811*da0073e9SAndroid Build Coastguard Worker
812*da0073e9SAndroid Build Coastguard Worker    Raises:
813*da0073e9SAndroid Build Coastguard Worker        AssertionError: if outputs from ONNX model and PyTorch model are not
814*da0073e9SAndroid Build Coastguard Worker            equal up to specified precision.
815*da0073e9SAndroid Build Coastguard Worker        ValueError: if arguments provided are invalid.
816*da0073e9SAndroid Build Coastguard Worker    """
817*da0073e9SAndroid Build Coastguard Worker    if options is None:
818*da0073e9SAndroid Build Coastguard Worker        options = VerificationOptions()
819*da0073e9SAndroid Build Coastguard Worker
820*da0073e9SAndroid Build Coastguard Worker    if training == torch.onnx.TrainingMode.TRAINING:
821*da0073e9SAndroid Build Coastguard Worker        model.train()
822*da0073e9SAndroid Build Coastguard Worker    elif training == torch.onnx.TrainingMode.EVAL:
823*da0073e9SAndroid Build Coastguard Worker        model.eval()
824*da0073e9SAndroid Build Coastguard Worker    with torch.no_grad(), contextlib.ExitStack() as stack:
825*da0073e9SAndroid Build Coastguard Worker        model_f: str | io.BytesIO = io.BytesIO()
826*da0073e9SAndroid Build Coastguard Worker        if use_external_data:
827*da0073e9SAndroid Build Coastguard Worker            tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
828*da0073e9SAndroid Build Coastguard Worker            model_f = os.path.join(tmpdir_path, "model.onnx")
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        # TODO(#77679): remove this and treat mutating model separately.
833*da0073e9SAndroid Build Coastguard Worker        model_copy = _try_clone_model(model)
834*da0073e9SAndroid Build Coastguard Worker        utils._export(
835*da0073e9SAndroid Build Coastguard Worker            model,
836*da0073e9SAndroid Build Coastguard Worker            inputs_for_export,
837*da0073e9SAndroid Build Coastguard Worker            model_f,
838*da0073e9SAndroid Build Coastguard Worker            opset_version=opset_version,
839*da0073e9SAndroid Build Coastguard Worker            do_constant_folding=do_constant_folding,
840*da0073e9SAndroid Build Coastguard Worker            keep_initializers_as_inputs=keep_initializers_as_inputs,
841*da0073e9SAndroid Build Coastguard Worker            dynamic_axes=dynamic_axes,
842*da0073e9SAndroid Build Coastguard Worker            input_names=input_names,
843*da0073e9SAndroid Build Coastguard Worker            output_names=output_names,
844*da0073e9SAndroid Build Coastguard Worker            fixed_batch_size=fixed_batch_size,
845*da0073e9SAndroid Build Coastguard Worker            training=training,
846*da0073e9SAndroid Build Coastguard Worker            verbose=verbose,
847*da0073e9SAndroid Build Coastguard Worker        )
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Worker        _compare_onnx_pytorch_model(
850*da0073e9SAndroid Build Coastguard Worker            pt_model=model_copy,
851*da0073e9SAndroid Build Coastguard Worker            onnx_model_f=model_f,
852*da0073e9SAndroid Build Coastguard Worker            input_args=input_args,
853*da0073e9SAndroid Build Coastguard Worker            input_kwargs=input_kwargs,
854*da0073e9SAndroid Build Coastguard Worker            additional_test_inputs=additional_test_inputs,
855*da0073e9SAndroid Build Coastguard Worker            options=options,
856*da0073e9SAndroid Build Coastguard Worker        )
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker
def verify_aten_graph(
    graph: torch.Graph,
    input_args: tuple[Any, ...],
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any] | None = None,
    verification_options: VerificationOptions | None = None,
) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
    """Run a TorchScript (aten) graph and its ONNX conversion, then compare outputs.

    The graph is copied and executed with the JIT interpreter to get reference
    outputs. The copy is then converted to an ONNX graph, serialized into an
    in-memory buffer, and executed with ``verification_options.backend``.

    Args:
        graph: TorchScript graph to verify. A copy is converted; the argument
            itself is kept only for reading input names and for error output.
        input_args: Positional inputs. ``None`` entries are skipped when running
            the JIT graph.
        export_options: Options used for the aten -> ONNX conversion.
        params_dict: Maps graph input debug names to values. It supplies every
            graph input that follows the non-``None`` positional inputs.
            Defaults to an empty dict.
        verification_options: Backend and comparison options. Defaults to
            ``VerificationOptions()``.

    Returns:
        ``(mismatch_error, onnx_graph, jit_outs, onnx_outs)``, where
        ``mismatch_error`` is the ``AssertionError`` raised by the output
        comparison, or ``None`` when the outputs agree.

    Raises:
        Exception: any error other than the comparison ``AssertionError`` raised
            while preparing ONNX inputs, running the backend, or comparing. It is
            re-raised after both graphs are printed.
    """
    options = (
        VerificationOptions() if verification_options is None else verification_options
    )
    params = {} if params_dict is None else params_dict

    source_graph = graph
    jit_graph = source_graph.copy()

    # Reference run: interpret the copied aten graph with the JIT interpreter.
    # Leading graph inputs take the non-None args; the remainder come from `params`.
    user_inputs = tuple(arg for arg in input_args if arg is not None)
    param_values = list(jit_graph.inputs())[len(user_inputs) :]
    weights = [params[value.debugName()] for value in param_values]
    assert all(w is not None for w in weights)
    # TODO: Only copy the argument if mutation is detected in Graph.
    user_inputs = copy.deepcopy(user_inputs)
    jit_outs = torch._C._jit_interpret_graph(  # type: ignore[attr-defined]
        jit_graph, user_inputs + tuple(weights)
    )
    if not isinstance(jit_outs, (list, tuple)):
        jit_outs = [jit_outs]

    # Convert the aten graph to ONNX and serialize it into an in-memory buffer.
    onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
        jit_graph, export_options, params
    )
    proto, export_map = _onnx_proto_from_onnx_graph(
        onnx_graph, export_options, onnx_params_dict
    )
    model_f: str | io.BytesIO = io.BytesIO()
    onnx_proto_utils._export_file(
        proto, model_f, _exporter_states.ExportTypes.PROTOBUF_FILE, export_map
    )

    # NOTE: Verification is unstable. Try catch to emit information for debugging.
    try:
        # Conversion may dead-code-eliminate inputs; keep only the args whose graph
        # input (matched by debug name) still exists in the ONNX graph.
        kept_names = {value.debugName() for value in onnx_graph.inputs()}
        kept_args = tuple(
            arg
            for value, arg in zip(source_graph.inputs(), input_args)
            if value.debugName() in kept_names
        )
        onnx_inputs = _prepare_input_for_onnx(
            kept_args,
            {},
            options.remained_onnx_input_idx,
            options.flatten,
        )

        session = _onnx_backend_session(model_f, options.backend)
        onnx_outs = _run_onnx(session, onnx_inputs)
        del session  # To free device memory

        # A numeric mismatch is a result to report, not an unexpected failure.
        mismatch: AssertionError | None = None
        try:
            _compare_onnx_pytorch_outputs(
                onnx_outs=onnx_outs,
                pt_outs=jit_outs,
                options=options,
            )
        except AssertionError as err:
            mismatch = err
        return mismatch, onnx_graph, jit_outs, onnx_outs
    except Exception:
        print("Unexpected error during verification.")
        print("jit graph: ", source_graph)
        print("onnx graph: ", onnx_graph)
        raise
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker
class GraphInfoPrettyPrinter:
    """Render a `GraphInfo` partition tree as text.

    A printer wraps one `GraphInfo` node. When that node has both an upper and a
    lower subgraph, it builds one child printer for each. Output is produced line
    by line, and each line has three parts: this node's own segment, a connector
    segment, and the matching line of the children.
    """

    graph_info: GraphInfo | None
    upper_printer: GraphInfoPrettyPrinter | None
    lower_printer: GraphInfoPrettyPrinter | None

    def __init__(self, graph_info: GraphInfo | None):
        self.graph_info = graph_info
        # Children are only drawn when both halves of the partition exist.
        if (
            graph_info is not None
            and graph_info.upper_graph_info is not None
            and graph_info.lower_graph_info is not None
        ):
            self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info)
            self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info)
        else:
            self.upper_printer = None
            self.lower_printer = None

    def _total_rows(self) -> int:
        """Return the number of text lines this subtree occupies."""
        if self.graph_info is None:
            return 1
        if self.upper_printer and self.lower_printer:
            # Upper subtree, one separator row, then the lower subtree.
            return (
                self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1
            )
        return 2  # Two lines: node count + id.

    def _node_count_segment_str(self) -> str:
        """Return ``"<count> <mark> <kind>"`` for this node, or ``"..."`` if absent.

        ``<mark>`` is ``X`` when the node has a mismatch and a check mark otherwise.
        ``<kind>`` is shown only for a mismatching node with exactly one essential
        node.
        """
        if self.graph_info is None:
            return "..."
        node_count = self.graph_info.essential_node_count()
        has_mismatch = self.graph_info.has_mismatch()
        error_node_kind = ""
        if node_count == 1 and has_mismatch:
            # Peek at the single kind without `pop()`, so the set returned by
            # `essential_node_kinds()` is never mutated. This method runs once per
            # rendered line.
            node_kinds = self.graph_info.essential_node_kinds()
            error_node_kind = f"({next(iter(node_kinds))})"

        return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}"

    def _graph_id_segment_str(self) -> str:
        """Return ``"id: <id>"`` for this node, or ``""`` if absent."""
        if self.graph_info is None:
            return ""
        return f"id: {self.graph_info.id}"

    def _max_segment_columns(self) -> int:
        """Return the width of this node's own segment column."""
        return max(
            map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
        )

    def _graph_segment_str_at_line(self, line: int) -> str:
        """Get the string representation of the graph segment at the given line.

        Line 0 is the node-count text and line 1 is the id text, both padded to
        `_max_segment_columns()`. Other lines inside the subtree are blank padding;
        lines past the subtree are empty.
        """
        if line == 0:
            result_str = self._node_count_segment_str()
            result_str += " " * (self._max_segment_columns() - len(result_str))
            return result_str
        if line == 1:
            result_str = self._graph_id_segment_str()
            result_str += " " * (self._max_segment_columns() - len(result_str))
            return result_str
        if 0 <= line < self._total_rows():
            return " " * self._max_segment_columns()
        return ""

    def _connector_segment_str_at_line(self, line: int) -> str:
        """Get the connector segment string at the given line.

        This draws the branch from this node to its upper child (line 0) and down
        to its lower child (line ``upper_total_rows + 1``).
        """
        if self.upper_printer is None and self.lower_printer is None:
            return ""
        upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
        lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
        if line == 0:
            return "  __"
        elif line < upper_total_rows + 1:
            return " |  "
        elif line == upper_total_rows + 1:
            return " |__"
        elif line < upper_total_rows + lower_total_rows + 1:
            return "    "
        return ""

    def _children_str_at_line(self, line: int) -> str:
        """Get the string representation of the children at the given line.

        Recursively calls `_str_at_line` on children nodes. The upper child fills
        lines ``[0, upper_total_rows)``. Line ``upper_total_rows`` is a separator.
        The lower child follows it.
        """
        if self.upper_printer is None and self.lower_printer is None:
            return ""
        upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
        lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
        if 0 <= line < upper_total_rows:
            return (
                self.upper_printer._str_at_line(line) if self.upper_printer else "..."
            )
        elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1:
            return (
                self.lower_printer._str_at_line(line - upper_total_rows - 1)
                if self.lower_printer
                else "..."
            )
        return ""

    def _str_at_line(self, line: int) -> str:
        """Get the string representation of the graph at the given line."""
        return (
            self._graph_segment_str_at_line(line)
            + self._connector_segment_str_at_line(line)
            + self._children_str_at_line(line)
        )

    def pretty_print(self):
        """Print the tree, then summarize mismatching leaf subgraphs and node kinds.

        Prints ``None`` when there is no `GraphInfo` to render.
        """
        if self.graph_info is None:
            print(None)
            return
        # Print tree.
        print(" Tree: ".center(80, "="))
        for line in range(self._total_rows()):
            print(self._str_at_line(line).rstrip())
        if not self.graph_info.has_mismatch():
            print(" No mismatch found. ".center(80, "="))
            return

        # Walk the tree once; both summaries below use the same leaves.
        mismatch_leaves = list(self.graph_info.all_mismatch_leaf_graph_info())
        # Summarize leaf subgraphs with mismatch.
        print(" Mismatch leaf subgraphs: ".center(80, "="))
        print([leaf.id for leaf in mismatch_leaves])
        # Summarize node kinds with mismatch; only single-kind leaves are counted.
        mismatch_node_kinds: dict[str, int] = {}
        for leaf in mismatch_leaves:
            node_kinds = leaf.essential_node_kinds()
            if len(node_kinds) == 1:
                node_kind = next(iter(node_kinds))
                mismatch_node_kinds[node_kind] = (
                    mismatch_node_kinds.get(node_kind, 0) + 1
                )
        print(" Mismatch node kinds: ".center(80, "="))
        print(mismatch_node_kinds)
1081*da0073e9SAndroid Build Coastguard Worker
1082*da0073e9SAndroid Build Coastguard Worker
class OnnxTestCaseRepro:
    """An ONNX test case on disk: a model proto plus recorded inputs and outputs.

    Constructing an instance loads an existing test-case directory.
    `create_test_case_repro` writes a new one.
    """

    def __init__(self, repro_dir):
        """Load the test case stored in ``repro_dir``.

        Sets ``proto``, ``inputs`` and ``outputs`` from the tuple returned by
        ``onnx_proto_utils.load_test_case(repro_dir)``.
        """
        self.repro_dir = repro_dir
        loaded_case = onnx_proto_utils.load_test_case(repro_dir)
        self.proto, self.inputs, self.outputs = loaded_case

    @classmethod
    def create_test_case_repro(
        cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None
    ):
        """Write an ONNX test case to "{dir}/test_{name}" and return its path.

        The saved case holds the model and its input/output data, laid out as:

        dir
        \u251c\u2500\u2500 test_<name>
        \u2502   \u251c\u2500\u2500 model.onnx
        \u2502   \u2514\u2500\u2500 test_data_set_0
        \u2502       \u251c\u2500\u2500 input_0.pb
        \u2502       \u251c\u2500\u2500 input_1.pb
        \u2502       \u251c\u2500\u2500 output_0.pb
        \u2502       \u2514\u2500\u2500 output_1.pb

        Args:
            proto: Serialized ONNX model proto.
            inputs: Model inputs; converted with `_to_numpy` before saving.
            outputs: Model outputs; converted with `_to_numpy` before saving.
            dir: Parent directory that receives the test case.
            name: Test case name. If omitted, a timestamp of the current local
                time (down to microseconds) is used.

        Returns:
            Path to the repro.
        """
        if name is not None:
            case_name = name
        else:
            case_name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
        np_inputs = _to_numpy(inputs)
        np_outputs = _to_numpy(outputs)
        return onnx_proto_utils.export_as_test_case(
            proto, np_inputs, np_outputs, case_name, dir
        )

    def validate(self, options: VerificationOptions):
        """Run the stored model on ``options.backend`` and compare with saved outputs.

        Args:
            options: Options for validation, including the backend and tolerances.

        Raises:
            AssertionError: if outputs from options.backend and expected outputs are
                not equal up to specified precision.
            ValueError: if the backend session exposes neither ``get_outputs`` nor
                ``output_names``.
        """
        session = _onnx_backend_session(io.BytesIO(self.proto), options.backend)
        actual_outputs = session.run(None, self.inputs)
        # Backend sessions differ in how they expose output names.
        if hasattr(session, "get_outputs"):
            names = [meta.name for meta in session.get_outputs()]
        elif hasattr(session, "output_names"):
            names = session.output_names
        else:
            raise ValueError(f"Unknown onnx session type: {type(session)}")
        # Saved outputs are keyed by output name; align them with the run order.
        expected_outputs = [self.outputs[output_name] for output_name in names]
        _compare_onnx_pytorch_outputs_in_np(actual_outputs, expected_outputs, options)
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass
1151*da0073e9SAndroid Build Coastguard Workerclass GraphInfo:
1152*da0073e9SAndroid Build Coastguard Worker    """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker    graph: torch.Graph
1155*da0073e9SAndroid Build Coastguard Worker    input_args: tuple[Any, ...]
1156*da0073e9SAndroid Build Coastguard Worker    params_dict: dict[str, Any]
1157*da0073e9SAndroid Build Coastguard Worker    export_options: _experimental.ExportOptions = dataclasses.field(
1158*da0073e9SAndroid Build Coastguard Worker        default_factory=_experimental.ExportOptions
1159*da0073e9SAndroid Build Coastguard Worker    )
1160*da0073e9SAndroid Build Coastguard Worker    mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False)
1161*da0073e9SAndroid Build Coastguard Worker    pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False)
1162*da0073e9SAndroid Build Coastguard Worker    upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
1163*da0073e9SAndroid Build Coastguard Worker    lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
1164*da0073e9SAndroid Build Coastguard Worker    id: str = dataclasses.field(default="")
1165*da0073e9SAndroid Build Coastguard Worker    _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None)
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker    _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset(
1168*da0073e9SAndroid Build Coastguard Worker        {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"}
1169*da0073e9SAndroid Build Coastguard Worker    )
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker    def clear(self):
1172*da0073e9SAndroid Build Coastguard Worker        """Clear states and results of previous verification."""
1173*da0073e9SAndroid Build Coastguard Worker        self.mismatch_error = None
1174*da0073e9SAndroid Build Coastguard Worker        self.pt_outs = None
1175*da0073e9SAndroid Build Coastguard Worker        self._onnx_graph = None
1176*da0073e9SAndroid Build Coastguard Worker        self.upper_graph_info = None
1177*da0073e9SAndroid Build Coastguard Worker        self.lower_graph_info = None
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker    def pretty_print_tree(self):
1180*da0073e9SAndroid Build Coastguard Worker        """Pretty print `GraphInfo` tree.
1181*da0073e9SAndroid Build Coastguard Worker
1182*da0073e9SAndroid Build Coastguard Worker        Each node represents a subgraph, showing the number of nodes in the subgraph and
1183*da0073e9SAndroid Build Coastguard Worker        a check mark if the subgraph has output mismatch between torch and ONNX.
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        The id of the subgraph is shown under the node. The `GraphInfo` object for any
1186*da0073e9SAndroid Build Coastguard Worker        subgraph can be retrieved by calling `graph_info.find_partition(id)`.
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker        Example::
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker            ==================================== Tree: =====================================
1191*da0073e9SAndroid Build Coastguard Worker            5 X   __2 X    __1 \u2713
1192*da0073e9SAndroid Build Coastguard Worker            id:  |  id: 0 |  id: 00
1193*da0073e9SAndroid Build Coastguard Worker                 |        |
1194*da0073e9SAndroid Build Coastguard Worker                 |        |__1 X (aten::relu)
1195*da0073e9SAndroid Build Coastguard Worker                 |           id: 01
1196*da0073e9SAndroid Build Coastguard Worker                 |
1197*da0073e9SAndroid Build Coastguard Worker                 |__3 X    __1 \u2713
1198*da0073e9SAndroid Build Coastguard Worker                    id: 1 |  id: 10
1199*da0073e9SAndroid Build Coastguard Worker                          |
1200*da0073e9SAndroid Build Coastguard Worker                          |__2 X     __1 X (aten::relu)
1201*da0073e9SAndroid Build Coastguard Worker                             id: 11 |  id: 110
1202*da0073e9SAndroid Build Coastguard Worker                                    |
1203*da0073e9SAndroid Build Coastguard Worker                                    |__1 \u2713
1204*da0073e9SAndroid Build Coastguard Worker                                       id: 111
1205*da0073e9SAndroid Build Coastguard Worker            =========================== Mismatch leaf subgraphs: ===========================
1206*da0073e9SAndroid Build Coastguard Worker            ['01', '110']
1207*da0073e9SAndroid Build Coastguard Worker            ============================= Mismatch node kinds: =============================
1208*da0073e9SAndroid Build Coastguard Worker            {'aten::relu': 2}
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker        """
1211*da0073e9SAndroid Build Coastguard Worker        GraphInfoPrettyPrinter(self).pretty_print()
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker    def pretty_print_mismatch(self, graph: bool = False):
1214*da0073e9SAndroid Build Coastguard Worker        """Pretty print details of the mismatch between torch and ONNX.
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        Args:
1217*da0073e9SAndroid Build Coastguard Worker            graph: If True, print the ATen JIT graph and ONNX graph.
1218*da0073e9SAndroid Build Coastguard Worker        """
1219*da0073e9SAndroid Build Coastguard Worker        print(f" Mismatch info for graph partition {self.id}: ".center(80, "="))
1220*da0073e9SAndroid Build Coastguard Worker        if graph:
1221*da0073e9SAndroid Build Coastguard Worker            print(" ATen JIT graph ".center(80, "="))
1222*da0073e9SAndroid Build Coastguard Worker            # TODO: A more compact graph printer.
1223*da0073e9SAndroid Build Coastguard Worker            #   * Drop stride, grad, device information.
1224*da0073e9SAndroid Build Coastguard Worker            #   * Show source location on a separate line.
1225*da0073e9SAndroid Build Coastguard Worker            print(self.graph)
1226*da0073e9SAndroid Build Coastguard Worker            if self._onnx_graph is not None:
1227*da0073e9SAndroid Build Coastguard Worker                print(" ONNX graph ".center(80, "="))
1228*da0073e9SAndroid Build Coastguard Worker                print(self._onnx_graph)
1229*da0073e9SAndroid Build Coastguard Worker        if self.has_mismatch():
1230*da0073e9SAndroid Build Coastguard Worker            print(" Mismatch error ".center(80, "="))
1231*da0073e9SAndroid Build Coastguard Worker            print(self.mismatch_error)
1232*da0073e9SAndroid Build Coastguard Worker        else:
1233*da0073e9SAndroid Build Coastguard Worker            print(" No mismatch ".center(80, "="))
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker    def has_mismatch(self) -> bool:
1236*da0073e9SAndroid Build Coastguard Worker        """Return True if the subgraph has output mismatch between torch and ONNX."""
1237*da0073e9SAndroid Build Coastguard Worker        return self.mismatch_error is not None
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker    def essential_node_count(self) -> int:
1240*da0073e9SAndroid Build Coastguard Worker        """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
1241*da0073e9SAndroid Build Coastguard Worker        return sum(
1242*da0073e9SAndroid Build Coastguard Worker            1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
1243*da0073e9SAndroid Build Coastguard Worker        )
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker    def essential_node_kinds(self) -> set[str]:
1246*da0073e9SAndroid Build Coastguard Worker        """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
1247*da0073e9SAndroid Build Coastguard Worker        return {
1248*da0073e9SAndroid Build Coastguard Worker            n.kind()
1249*da0073e9SAndroid Build Coastguard Worker            for n in self.graph.nodes()
1250*da0073e9SAndroid Build Coastguard Worker            if n.kind() not in self._EXCLUDED_NODE_KINDS
1251*da0073e9SAndroid Build Coastguard Worker        }
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker    def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]:
1254*da0073e9SAndroid Build Coastguard Worker        """Return a list of all leaf `GraphInfo` objects that have mismatch."""
1255*da0073e9SAndroid Build Coastguard Worker        if not self.has_mismatch():
1256*da0073e9SAndroid Build Coastguard Worker            return []
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker        no_mismatch_children = (
1259*da0073e9SAndroid Build Coastguard Worker            self.upper_graph_info is None or not self.upper_graph_info.has_mismatch()
1260*da0073e9SAndroid Build Coastguard Worker        ) and (
1261*da0073e9SAndroid Build Coastguard Worker            self.lower_graph_info is None or not self.lower_graph_info.has_mismatch()
1262*da0073e9SAndroid Build Coastguard Worker        )
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker        if no_mismatch_children:
1265*da0073e9SAndroid Build Coastguard Worker            return [self]
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker        results = []
1268*da0073e9SAndroid Build Coastguard Worker        if self.upper_graph_info is not None:
1269*da0073e9SAndroid Build Coastguard Worker            results += self.upper_graph_info.all_mismatch_leaf_graph_info()
1270*da0073e9SAndroid Build Coastguard Worker        if self.lower_graph_info is not None:
1271*da0073e9SAndroid Build Coastguard Worker            results += self.lower_graph_info.all_mismatch_leaf_graph_info()
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker        return results
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker    def find_partition(self, id: str) -> GraphInfo | None:
1276*da0073e9SAndroid Build Coastguard Worker        """Find the `GraphInfo` object with the given id."""
1277*da0073e9SAndroid Build Coastguard Worker        if id == self.id:
1278*da0073e9SAndroid Build Coastguard Worker            return self
1279*da0073e9SAndroid Build Coastguard Worker        current_length = len(self.id)
1280*da0073e9SAndroid Build Coastguard Worker        if len(id) > current_length:
1281*da0073e9SAndroid Build Coastguard Worker            if id[current_length] == "0" and self.upper_graph_info is not None:
1282*da0073e9SAndroid Build Coastguard Worker                return self.upper_graph_info.find_partition(id)
1283*da0073e9SAndroid Build Coastguard Worker            elif id[current_length] == "1" and self.lower_graph_info is not None:
1284*da0073e9SAndroid Build Coastguard Worker                return self.lower_graph_info.find_partition(id)
1285*da0073e9SAndroid Build Coastguard Worker        return None
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker    def export_repro(
1288*da0073e9SAndroid Build Coastguard Worker        self, repro_dir: str | None = None, name: str | None = None
1289*da0073e9SAndroid Build Coastguard Worker    ) -> str:
1290*da0073e9SAndroid Build Coastguard Worker        """Export the subgraph to ONNX along with the input/output data for repro.
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker        The repro directory will contain the following files::
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker            dir
1295*da0073e9SAndroid Build Coastguard Worker            \u251c\u2500\u2500 test_<name>
1296*da0073e9SAndroid Build Coastguard Worker            \u2502   \u251c\u2500\u2500 model.onnx
1297*da0073e9SAndroid Build Coastguard Worker            \u2502   \u2514\u2500\u2500 test_data_set_0
1298*da0073e9SAndroid Build Coastguard Worker            \u2502       \u251c\u2500\u2500 input_0.pb
1299*da0073e9SAndroid Build Coastguard Worker            \u2502       \u251c\u2500\u2500 input_1.pb
1300*da0073e9SAndroid Build Coastguard Worker            \u2502       \u251c\u2500\u2500 output_0.pb
1301*da0073e9SAndroid Build Coastguard Worker            \u2502       \u2514\u2500\u2500 output_1.pb
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        Args:
1304*da0073e9SAndroid Build Coastguard Worker            repro_dir: The directory to export the repro files to. Defaults to current
1305*da0073e9SAndroid Build Coastguard Worker                working directory if None.
1306*da0073e9SAndroid Build Coastguard Worker            name: An optional name for the test case folder: "test_{name}".
1307*da0073e9SAndroid Build Coastguard Worker
1308*da0073e9SAndroid Build Coastguard Worker        Returns:
1309*da0073e9SAndroid Build Coastguard Worker            The path to the exported repro directory.
1310*da0073e9SAndroid Build Coastguard Worker        """
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker        if repro_dir is None:
1313*da0073e9SAndroid Build Coastguard Worker            repro_dir = os.getcwd()
1314*da0073e9SAndroid Build Coastguard Worker        repro_dir = os.path.join(repro_dir, "onnx_debug")
1315*da0073e9SAndroid Build Coastguard Worker
1316*da0073e9SAndroid Build Coastguard Worker        onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
1317*da0073e9SAndroid Build Coastguard Worker            self.graph, self.export_options, self.params_dict
1318*da0073e9SAndroid Build Coastguard Worker        )
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker        proto, _ = _onnx_proto_from_onnx_graph(
1321*da0073e9SAndroid Build Coastguard Worker            onnx_graph, self.export_options, onnx_params_dict
1322*da0073e9SAndroid Build Coastguard Worker        )
1323*da0073e9SAndroid Build Coastguard Worker        return OnnxTestCaseRepro.create_test_case_repro(
1324*da0073e9SAndroid Build Coastguard Worker            proto, self.input_args, self.pt_outs, repro_dir, name
1325*da0073e9SAndroid Build Coastguard Worker        )
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker    def _graph_partition_pivot(self) -> int:
1328*da0073e9SAndroid Build Coastguard Worker        """Find the pivot index to partition the graph.
1329*da0073e9SAndroid Build Coastguard Worker
1330*da0073e9SAndroid Build Coastguard Worker        The pivot is the node that splits the graph into two parts. Each part should
1331*da0073e9SAndroid Build Coastguard Worker        have the similar amount of nodes, excluding non essential ops, defined in
1332*da0073e9SAndroid Build Coastguard Worker        `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`.
1333*da0073e9SAndroid Build Coastguard Worker        If the graph has an odd number of nodes, the upper part will have one more node.
1334*da0073e9SAndroid Build Coastguard Worker        If the graph does not have any node that can be partitioned, return -1.
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker        Returns:
1337*da0073e9SAndroid Build Coastguard Worker            The index of the pivot node.
1338*da0073e9SAndroid Build Coastguard Worker        """
1339*da0073e9SAndroid Build Coastguard Worker        included_node_indices = [
1340*da0073e9SAndroid Build Coastguard Worker            i
1341*da0073e9SAndroid Build Coastguard Worker            for i, n in enumerate(self.graph.nodes())
1342*da0073e9SAndroid Build Coastguard Worker            if n.kind() not in self._EXCLUDED_NODE_KINDS
1343*da0073e9SAndroid Build Coastguard Worker        ]
1344*da0073e9SAndroid Build Coastguard Worker        half_idx = len(included_node_indices) // 2 - 1
1345*da0073e9SAndroid Build Coastguard Worker        if half_idx >= 0 and len(included_node_indices) > half_idx:
1346*da0073e9SAndroid Build Coastguard Worker            return included_node_indices[half_idx] + 1
1347*da0073e9SAndroid Build Coastguard Worker        return -1
1348*da0073e9SAndroid Build Coastguard Worker
1349*da0073e9SAndroid Build Coastguard Worker    def _partition_upper_graph(self) -> torch.Graph:
1350*da0073e9SAndroid Build Coastguard Worker        pivot = self._graph_partition_pivot()
1351*da0073e9SAndroid Build Coastguard Worker        if pivot == -1:
1352*da0073e9SAndroid Build Coastguard Worker            return torch.Graph()
1353*da0073e9SAndroid Build Coastguard Worker        graph = self.graph.copy()  # Copy to not mutate parent graph.
1354*da0073e9SAndroid Build Coastguard Worker        original_outputs = list(graph.outputs())
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker        def _process_bridge_value_for_upper(
1357*da0073e9SAndroid Build Coastguard Worker            new_outputs: list[torch.Value], bridge_value: torch.Value
1358*da0073e9SAndroid Build Coastguard Worker        ) -> torch.Value:
1359*da0073e9SAndroid Build Coastguard Worker            # Add bridge values as upper graph outputs.
1360*da0073e9SAndroid Build Coastguard Worker            new_outputs.append(bridge_value)
1361*da0073e9SAndroid Build Coastguard Worker            return bridge_value
1362*da0073e9SAndroid Build Coastguard Worker
1363*da0073e9SAndroid Build Coastguard Worker        new_outputs: list[torch.Value] = []
1364*da0073e9SAndroid Build Coastguard Worker        process_bridge_value_for_upper = functools.partial(
1365*da0073e9SAndroid Build Coastguard Worker            _process_bridge_value_for_upper, new_outputs
1366*da0073e9SAndroid Build Coastguard Worker        )
1367*da0073e9SAndroid Build Coastguard Worker        _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes(
1368*da0073e9SAndroid Build Coastguard Worker            graph, pivot, process_bridge_value_for_upper
1369*da0073e9SAndroid Build Coastguard Worker        )
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker        for _ in enumerate(original_outputs):
1372*da0073e9SAndroid Build Coastguard Worker            graph.eraseOutput(0)
1373*da0073e9SAndroid Build Coastguard Worker        for output in new_outputs:
1374*da0073e9SAndroid Build Coastguard Worker            graph.registerOutput(output)
1375*da0073e9SAndroid Build Coastguard Worker
1376*da0073e9SAndroid Build Coastguard Worker        for node in reversed(dropped_nodes):
1377*da0073e9SAndroid Build Coastguard Worker            node.destroy()
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker        for i, input in reversed(list(enumerate(list(graph.inputs())))):
1380*da0073e9SAndroid Build Coastguard Worker            if (
1381*da0073e9SAndroid Build Coastguard Worker                not _has_uses_by_nodes(input, complete_upper_nodes_set)
1382*da0073e9SAndroid Build Coastguard Worker                and input not in new_outputs
1383*da0073e9SAndroid Build Coastguard Worker            ):
1384*da0073e9SAndroid Build Coastguard Worker                try:
1385*da0073e9SAndroid Build Coastguard Worker                    graph.eraseInput(i)
1386*da0073e9SAndroid Build Coastguard Worker                except RuntimeError as e:
1387*da0073e9SAndroid Build Coastguard Worker                    print(input, graph)
1388*da0073e9SAndroid Build Coastguard Worker                    raise e
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker        return graph
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker    def _partition_lower_graph(self) -> torch.Graph:
1393*da0073e9SAndroid Build Coastguard Worker        pivot = self._graph_partition_pivot()
1394*da0073e9SAndroid Build Coastguard Worker        if pivot == -1:
1395*da0073e9SAndroid Build Coastguard Worker            return torch.Graph()
1396*da0073e9SAndroid Build Coastguard Worker        graph = self.graph.copy()  # Copy to not mutate parent graph.
1397*da0073e9SAndroid Build Coastguard Worker        original_outputs = list(graph.outputs())
1398*da0073e9SAndroid Build Coastguard Worker        original_inputs = list(graph.inputs())
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker        new_outputs = []
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        def _process_bridge_value_for_lower(
1403*da0073e9SAndroid Build Coastguard Worker            graph: torch.Graph, bridge_value: torch.Value
1404*da0073e9SAndroid Build Coastguard Worker        ) -> torch.Value:
1405*da0073e9SAndroid Build Coastguard Worker            # Add bridge values as lower graph inputs.
1406*da0073e9SAndroid Build Coastguard Worker            new_input = graph.addInput()
1407*da0073e9SAndroid Build Coastguard Worker            bridge_value.replaceAllUsesWith(new_input)
1408*da0073e9SAndroid Build Coastguard Worker            new_input.copyMetadata(bridge_value)
1409*da0073e9SAndroid Build Coastguard Worker            return new_input
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker        process_bridge_value_for_lower = functools.partial(
1412*da0073e9SAndroid Build Coastguard Worker            _process_bridge_value_for_lower, graph
1413*da0073e9SAndroid Build Coastguard Worker        )
1414*da0073e9SAndroid Build Coastguard Worker
1415*da0073e9SAndroid Build Coastguard Worker        upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes(
1416*da0073e9SAndroid Build Coastguard Worker            graph, pivot, process_bridge_value_for_lower
1417*da0073e9SAndroid Build Coastguard Worker        )
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker        for output in original_outputs:
1420*da0073e9SAndroid Build Coastguard Worker            if _produced_by(output, lower_nodes):
1421*da0073e9SAndroid Build Coastguard Worker                new_outputs.append(output)
1422*da0073e9SAndroid Build Coastguard Worker        for _ in enumerate(original_outputs):
1423*da0073e9SAndroid Build Coastguard Worker            graph.eraseOutput(0)
1424*da0073e9SAndroid Build Coastguard Worker        for output in new_outputs:
1425*da0073e9SAndroid Build Coastguard Worker            graph.registerOutput(output)
1426*da0073e9SAndroid Build Coastguard Worker
1427*da0073e9SAndroid Build Coastguard Worker        for input in original_inputs:
1428*da0073e9SAndroid Build Coastguard Worker            if _has_uses_by_nodes(input, complete_lower_nodes_set):
1429*da0073e9SAndroid Build Coastguard Worker                new_input = graph.addInput()
1430*da0073e9SAndroid Build Coastguard Worker                input.replaceAllUsesWith(new_input)
1431*da0073e9SAndroid Build Coastguard Worker                new_input.copyMetadata(input)
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker        for node in reversed(upper_nodes):
1434*da0073e9SAndroid Build Coastguard Worker            if node not in complete_lower_nodes_set:
1435*da0073e9SAndroid Build Coastguard Worker                try:
1436*da0073e9SAndroid Build Coastguard Worker                    node.destroy()
1437*da0073e9SAndroid Build Coastguard Worker                except RuntimeError as e:
1438*da0073e9SAndroid Build Coastguard Worker                    print(node, graph)
1439*da0073e9SAndroid Build Coastguard Worker                    raise e
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker        for _ in original_inputs:
1442*da0073e9SAndroid Build Coastguard Worker            graph.eraseInput(0)
1443*da0073e9SAndroid Build Coastguard Worker
1444*da0073e9SAndroid Build Coastguard Worker        return graph
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    def _partition_node(
1447*da0073e9SAndroid Build Coastguard Worker        self,
1448*da0073e9SAndroid Build Coastguard Worker        node: torch.Node,
1449*da0073e9SAndroid Build Coastguard Worker        complete_upper_nodes_set: set[torch.Node],
1450*da0073e9SAndroid Build Coastguard Worker        complete_lower_nodes_set: set[torch.Node],
1451*da0073e9SAndroid Build Coastguard Worker        original_graph_outputs: set[torch.Value],
1452*da0073e9SAndroid Build Coastguard Worker        covered_bridge_values: set[torch.Value],
1453*da0073e9SAndroid Build Coastguard Worker        process_bridge_value: Callable[[torch.Value], torch.Value],
1454*da0073e9SAndroid Build Coastguard Worker    ):
1455*da0073e9SAndroid Build Coastguard Worker        if node in complete_lower_nodes_set:
1456*da0073e9SAndroid Build Coastguard Worker            return
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker        if (
1459*da0073e9SAndroid Build Coastguard Worker            _node_has_uses_by(node, complete_lower_nodes_set)
1460*da0073e9SAndroid Build Coastguard Worker            and node.kind() in self._EXCLUDED_NODE_KINDS
1461*da0073e9SAndroid Build Coastguard Worker        ):
1462*da0073e9SAndroid Build Coastguard Worker            complete_lower_nodes_set.update(_all_nodes([node]))
1463*da0073e9SAndroid Build Coastguard Worker            for input in node.inputs():
1464*da0073e9SAndroid Build Coastguard Worker                if input in covered_bridge_values:
1465*da0073e9SAndroid Build Coastguard Worker                    continue
1466*da0073e9SAndroid Build Coastguard Worker                self._partition_node(
1467*da0073e9SAndroid Build Coastguard Worker                    input.node(),
1468*da0073e9SAndroid Build Coastguard Worker                    complete_upper_nodes_set,
1469*da0073e9SAndroid Build Coastguard Worker                    complete_lower_nodes_set,
1470*da0073e9SAndroid Build Coastguard Worker                    original_graph_outputs,
1471*da0073e9SAndroid Build Coastguard Worker                    covered_bridge_values,
1472*da0073e9SAndroid Build Coastguard Worker                    process_bridge_value,
1473*da0073e9SAndroid Build Coastguard Worker                )
1474*da0073e9SAndroid Build Coastguard Worker        else:
1475*da0073e9SAndroid Build Coastguard Worker            for output in node.outputs():
1476*da0073e9SAndroid Build Coastguard Worker                if output in covered_bridge_values:
1477*da0073e9SAndroid Build Coastguard Worker                    continue
1478*da0073e9SAndroid Build Coastguard Worker                if (
1479*da0073e9SAndroid Build Coastguard Worker                    _has_uses_by_nodes(output, complete_lower_nodes_set)
1480*da0073e9SAndroid Build Coastguard Worker                    or output in original_graph_outputs
1481*da0073e9SAndroid Build Coastguard Worker                ):
1482*da0073e9SAndroid Build Coastguard Worker                    covered_bridge_values.add(process_bridge_value(output))
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker    def _partition_nodes(
1485*da0073e9SAndroid Build Coastguard Worker        self,
1486*da0073e9SAndroid Build Coastguard Worker        graph: torch.Graph,
1487*da0073e9SAndroid Build Coastguard Worker        pivot: int,
1488*da0073e9SAndroid Build Coastguard Worker        process_bridge_value: Callable[[torch.Value], torch.Value],
1489*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]:
1490*da0073e9SAndroid Build Coastguard Worker        nodes = list(graph.nodes())
1491*da0073e9SAndroid Build Coastguard Worker        upper_nodes = nodes[:pivot]
1492*da0073e9SAndroid Build Coastguard Worker        lower_nodes = nodes[pivot:]
1493*da0073e9SAndroid Build Coastguard Worker        # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter
1494*da0073e9SAndroid Build Coastguard Worker        # recursively contains nodes in subblock of `upper_nodes`.
1495*da0073e9SAndroid Build Coastguard Worker        # The same applies for `lower_nodes` and `complete_lower_nodes_set`.
1496*da0073e9SAndroid Build Coastguard Worker        # With addition that `complete_lower_nodes_set` will include nodes that
1497*da0073e9SAndroid Build Coastguard Worker        # are determined to be copied from `upper_nodes` to `lower_nodes`.
1498*da0073e9SAndroid Build Coastguard Worker        complete_upper_nodes_set = _all_nodes(upper_nodes)
1499*da0073e9SAndroid Build Coastguard Worker        complete_lower_nodes_set = _all_nodes(lower_nodes)
1500*da0073e9SAndroid Build Coastguard Worker        original_graph_outputs = set(graph.outputs())
1501*da0073e9SAndroid Build Coastguard Worker        # Bridge values are values produced from upper graph, and consumed
1502*da0073e9SAndroid Build Coastguard Worker        # by lower graph. These values need to be become upper graph outputs
1503*da0073e9SAndroid Build Coastguard Worker        # and lower graph inputs, to bridge the interaction.
1504*da0073e9SAndroid Build Coastguard Worker        # Start with all graph inputs marked as covered. If any graph input is
1505*da0073e9SAndroid Build Coastguard Worker        # needed by lower graph, just keep it in lower graph inputs later.
1506*da0073e9SAndroid Build Coastguard Worker        covered_bridge_values = set(graph.inputs())
1507*da0073e9SAndroid Build Coastguard Worker        for node in upper_nodes:
1508*da0073e9SAndroid Build Coastguard Worker            self._partition_node(
1509*da0073e9SAndroid Build Coastguard Worker                node,
1510*da0073e9SAndroid Build Coastguard Worker                complete_upper_nodes_set,
1511*da0073e9SAndroid Build Coastguard Worker                complete_lower_nodes_set,
1512*da0073e9SAndroid Build Coastguard Worker                original_graph_outputs,
1513*da0073e9SAndroid Build Coastguard Worker                covered_bridge_values,
1514*da0073e9SAndroid Build Coastguard Worker                process_bridge_value,
1515*da0073e9SAndroid Build Coastguard Worker            )
1516*da0073e9SAndroid Build Coastguard Worker        return (
1517*da0073e9SAndroid Build Coastguard Worker            upper_nodes,
1518*da0073e9SAndroid Build Coastguard Worker            lower_nodes,
1519*da0073e9SAndroid Build Coastguard Worker            complete_upper_nodes_set,
1520*da0073e9SAndroid Build Coastguard Worker            complete_lower_nodes_set,
1521*da0073e9SAndroid Build Coastguard Worker        )
1522*da0073e9SAndroid Build Coastguard Worker
1523*da0073e9SAndroid Build Coastguard Worker    def _bridge_kwargs(self):
1524*da0073e9SAndroid Build Coastguard Worker        pt_outs = self.pt_outs
1525*da0073e9SAndroid Build Coastguard Worker        graph_outputs = list(self.graph.outputs())
1526*da0073e9SAndroid Build Coastguard Worker        assert pt_outs is not None
1527*da0073e9SAndroid Build Coastguard Worker        assert len(graph_outputs) == len(
1528*da0073e9SAndroid Build Coastguard Worker            pt_outs
1529*da0073e9SAndroid Build Coastguard Worker        ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
1530*da0073e9SAndroid Build Coastguard Worker        return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker    def _args_and_params_for_partition_graph(
1533*da0073e9SAndroid Build Coastguard Worker        self,
1534*da0073e9SAndroid Build Coastguard Worker        graph: torch.Graph,
1535*da0073e9SAndroid Build Coastguard Worker        bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]],
1536*da0073e9SAndroid Build Coastguard Worker        full_kwargs: Mapping[str, torch.Tensor],
1537*da0073e9SAndroid Build Coastguard Worker        full_params: Mapping[str, torch.Tensor],
1538*da0073e9SAndroid Build Coastguard Worker    ):
1539*da0073e9SAndroid Build Coastguard Worker        input_names = [input.debugName() for input in graph.inputs()]
1540*da0073e9SAndroid Build Coastguard Worker        args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs)
1541*da0073e9SAndroid Build Coastguard Worker        args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs)
1542*da0073e9SAndroid Build Coastguard Worker        params = {k: full_params[k] for k in input_names if k in full_params}
1543*da0073e9SAndroid Build Coastguard Worker        assert len(args) + len(params) == len(
1544*da0073e9SAndroid Build Coastguard Worker            input_names
1545*da0073e9SAndroid Build Coastguard Worker        ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
1546*da0073e9SAndroid Build Coastguard Worker        return args, params
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Worker    def verify_export(
1549*da0073e9SAndroid Build Coastguard Worker        self, options: VerificationOptions
1550*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
1551*da0073e9SAndroid Build Coastguard Worker        """
1552*da0073e9SAndroid Build Coastguard Worker        Verify the export from TorchScript IR graph to ONNX.
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker        Export the TorchScript IR graph to ONNX, with the inputs, parameters and export
1555*da0073e9SAndroid Build Coastguard Worker        options recorded in this object. Then verify the exported ONNX graph against
1556*da0073e9SAndroid Build Coastguard Worker        the original TorchScript IR graph under the provided verification options.
1557*da0073e9SAndroid Build Coastguard Worker
1558*da0073e9SAndroid Build Coastguard Worker        Args:
1559*da0073e9SAndroid Build Coastguard Worker            options: The verification options.
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        Returns:
1562*da0073e9SAndroid Build Coastguard Worker            error: The AssertionError raised during the verification. Returns None if no
1563*da0073e9SAndroid Build Coastguard Worker            error is raised.
1564*da0073e9SAndroid Build Coastguard Worker            onnx_graph: The exported ONNX graph in TorchScript IR format.
1565*da0073e9SAndroid Build Coastguard Worker            onnx_outs: The outputs from running exported ONNX model under the onnx
1566*da0073e9SAndroid Build Coastguard Worker            backend in `options`.
1567*da0073e9SAndroid Build Coastguard Worker            pt_outs: The outputs from running the TorchScript IR graph.
1568*da0073e9SAndroid Build Coastguard Worker        """
1569*da0073e9SAndroid Build Coastguard Worker        return verify_aten_graph(
1570*da0073e9SAndroid Build Coastguard Worker            self.graph,
1571*da0073e9SAndroid Build Coastguard Worker            input_args=self.input_args,
1572*da0073e9SAndroid Build Coastguard Worker            params_dict=self.params_dict,
1573*da0073e9SAndroid Build Coastguard Worker            export_options=self.export_options,
1574*da0073e9SAndroid Build Coastguard Worker            verification_options=options,
1575*da0073e9SAndroid Build Coastguard Worker        )
1576*da0073e9SAndroid Build Coastguard Worker
1577*da0073e9SAndroid Build Coastguard Worker    def find_mismatch(
1578*da0073e9SAndroid Build Coastguard Worker        self,
1579*da0073e9SAndroid Build Coastguard Worker        options: VerificationOptions | None = None,
1580*da0073e9SAndroid Build Coastguard Worker    ):
1581*da0073e9SAndroid Build Coastguard Worker        """
1582*da0073e9SAndroid Build Coastguard Worker        Find all mismatches between the TorchScript IR graph and the exported onnx model.
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker        Binary searches the model graph to find the minimal subgraph that exhibits the
1585*da0073e9SAndroid Build Coastguard Worker        mismatch. A `GraphInfo` object is created for each subgraph, recording the test
1586*da0073e9SAndroid Build Coastguard Worker        inputs and export options, as well as the validation results.
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker        Args:
1589*da0073e9SAndroid Build Coastguard Worker            options: The verification options.
1590*da0073e9SAndroid Build Coastguard Worker        """
1591*da0073e9SAndroid Build Coastguard Worker        self.clear()
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        if options is None:
1594*da0073e9SAndroid Build Coastguard Worker            options = VerificationOptions()
1595*da0073e9SAndroid Build Coastguard Worker
1596*da0073e9SAndroid Build Coastguard Worker        if self.export_options.verbose:
1597*da0073e9SAndroid Build Coastguard Worker            print(self.graph)
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker        if len(list(self.graph.outputs())) == 0:
1600*da0073e9SAndroid Build Coastguard Worker            return
1601*da0073e9SAndroid Build Coastguard Worker
1602*da0073e9SAndroid Build Coastguard Worker        assert len(self.input_args) + len(self.params_dict) == len(
1603*da0073e9SAndroid Build Coastguard Worker            list(self.graph.inputs())
1604*da0073e9SAndroid Build Coastguard Worker        ), (
1605*da0073e9SAndroid Build Coastguard Worker            f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match "
1606*da0073e9SAndroid Build Coastguard Worker            f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})."
1607*da0073e9SAndroid Build Coastguard Worker        )
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker        self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export(
1610*da0073e9SAndroid Build Coastguard Worker            options
1611*da0073e9SAndroid Build Coastguard Worker        )
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker        if self.mismatch_error is None:
1614*da0073e9SAndroid Build Coastguard Worker            # No mismatch found in graph.
1615*da0073e9SAndroid Build Coastguard Worker            return
1616*da0073e9SAndroid Build Coastguard Worker
1617*da0073e9SAndroid Build Coastguard Worker        if self.essential_node_count() <= 1:
1618*da0073e9SAndroid Build Coastguard Worker            # Reached leaf node, no more partitioning.
1619*da0073e9SAndroid Build Coastguard Worker            return
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker        full_kwargs = {
1622*da0073e9SAndroid Build Coastguard Worker            k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args)
1623*da0073e9SAndroid Build Coastguard Worker        }
1624*da0073e9SAndroid Build Coastguard Worker        full_params = self.params_dict
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker        upper_graph = self._partition_upper_graph()
1627*da0073e9SAndroid Build Coastguard Worker        upper_args, upper_params = self._args_and_params_for_partition_graph(
1628*da0073e9SAndroid Build Coastguard Worker            upper_graph, {}, full_kwargs, full_params
1629*da0073e9SAndroid Build Coastguard Worker        )
1630*da0073e9SAndroid Build Coastguard Worker        self.upper_graph_info = GraphInfo(
1631*da0073e9SAndroid Build Coastguard Worker            upper_graph,
1632*da0073e9SAndroid Build Coastguard Worker            upper_args,
1633*da0073e9SAndroid Build Coastguard Worker            upper_params,
1634*da0073e9SAndroid Build Coastguard Worker            self.export_options,
1635*da0073e9SAndroid Build Coastguard Worker            id=self.id + "0",
1636*da0073e9SAndroid Build Coastguard Worker        )
1637*da0073e9SAndroid Build Coastguard Worker
1638*da0073e9SAndroid Build Coastguard Worker        self.upper_graph_info.find_mismatch(options)
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker        bridge_kwargs = self.upper_graph_info._bridge_kwargs()
1641*da0073e9SAndroid Build Coastguard Worker        lower_graph = self._partition_lower_graph()
1642*da0073e9SAndroid Build Coastguard Worker        lower_args, lower_params = self._args_and_params_for_partition_graph(
1643*da0073e9SAndroid Build Coastguard Worker            lower_graph, bridge_kwargs, full_kwargs, full_params
1644*da0073e9SAndroid Build Coastguard Worker        )
1645*da0073e9SAndroid Build Coastguard Worker        self.lower_graph_info = GraphInfo(
1646*da0073e9SAndroid Build Coastguard Worker            lower_graph,
1647*da0073e9SAndroid Build Coastguard Worker            lower_args,
1648*da0073e9SAndroid Build Coastguard Worker            lower_params,
1649*da0073e9SAndroid Build Coastguard Worker            self.export_options,
1650*da0073e9SAndroid Build Coastguard Worker            id=self.id + "1",
1651*da0073e9SAndroid Build Coastguard Worker        )
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker        self.lower_graph_info.find_mismatch(options)
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker
def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]:
    """Return ``nodes`` together with every node nested, at any depth, in their blocks.

    Uses an explicit work stack instead of recursion: each visited node is
    recorded, and the nodes of each of its sub-blocks are queued for visiting.
    """
    collected: set[torch.Node] = set()
    pending = list(nodes)
    while pending:
        current = pending.pop()
        collected.add(current)
        for sub_block in current.blocks():
            pending.extend(sub_block.nodes())
    return collected
1662*da0073e9SAndroid Build Coastguard Worker
1663*da0073e9SAndroid Build Coastguard Worker
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
    """Return True as soon as one use of ``value`` has its user node in ``nodes``."""
    for use in value.uses():
        if use.user in nodes:
            return True
    return False
1666*da0073e9SAndroid Build Coastguard Worker
1667*da0073e9SAndroid Build Coastguard Worker
def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
    """Return True if any output of ``node`` is used by a node in ``nodes``.

    Outputs are checked in order and the search stops at the first hit.
    """
    return any(_has_uses_by_nodes(out, nodes) for out in node.outputs())
1673*da0073e9SAndroid Build Coastguard Worker
1674*da0073e9SAndroid Build Coastguard Worker
def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
    """Return True if the node that produces ``value`` is a member of ``nodes``."""
    producer = value.node()
    return producer in nodes
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker
def find_mismatch(
    model: torch.nn.Module | torch.jit.ScriptModule,
    input_args: tuple[Any, ...],
    do_constant_folding: bool = True,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    options: VerificationOptions | None = None,
) -> GraphInfo:
    r"""Find all mismatches between the original model and the exported model.

    Experimental. The API is subject to change.

    This tool helps debug the mismatch between the original PyTorch model and exported
    ONNX model. It binary searches the model graph to find the minimal subgraph that
    exhibits the mismatch.

    When ``training`` is ``TrainingMode.TRAINING`` or ``TrainingMode.EVAL``, the model
    is switched to that mode for the duration of this call. Its previous ``training``
    flag is restored before the function returns or raises. For any other value
    (e.g. ``TrainingMode.PRESERVE``) the model's mode is not touched.

    Args:
        model: The model to be exported.
        input_args: The input arguments to the model.
        do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`.
        training: Same as `training` in :func:`torch.onnx.export`.
        opset_version: Same as `opset_version` in :func:`torch.onnx.export`.
            ``_constants.ONNX_DEFAULT_OPSET`` is used when None.
        keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`.
        verbose: Same as `verbose` in :func:`torch.onnx.export`.
        options: The options for the mismatch verification. ``VerificationOptions()``
            is used when None.

    Returns:
        A GraphInfo object that contains the mismatch information.

    Example::

        >>> import torch
        >>> import torch.onnx.verification
        >>> torch.manual_seed(0)
        >>> opset_version = 15
        >>> # Define a custom symbolic function for aten::relu.
        >>> # The custom symbolic function is incorrect, which will result in mismatches.
        >>> def incorrect_relu_symbolic_function(g, self):
        ...     return self
        >>> torch.onnx.register_custom_op_symbolic(
        ...     "aten::relu",
        ...     incorrect_relu_symbolic_function,
        ...     opset_version=opset_version,
        ... )
        >>> class Model(torch.nn.Module):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.layers = torch.nn.Sequential(
        ...             torch.nn.Linear(3, 4),
        ...             torch.nn.ReLU(),
        ...             torch.nn.Linear(4, 5),
        ...             torch.nn.ReLU(),
        ...             torch.nn.Linear(5, 6),
        ...         )
        ...     def forward(self, x):
        ...         return self.layers(x)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> graph_info = torch.onnx.verification.find_mismatch(
        ...     Model(),
        ...     (torch.randn(2, 3),),
        ...     opset_version=opset_version,
        ... )
        ===================== Mismatch info for graph partition : ======================
        ================================ Mismatch error ================================
        Tensor-likes are not close!
        Mismatched elements: 12 / 12 (100.0%)
        Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
        Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
        ==================================== Tree: =====================================
        5 X   __2 X    __1 \u2713
        id:  |  id: 0 |  id: 00
             |        |
             |        |__1 X (aten::relu)
             |           id: 01
             |
             |__3 X    __1 \u2713
                id: 1 |  id: 10
                      |
                      |__2 X     __1 X (aten::relu)
                         id: 11 |  id: 110
                                |
                                |__1 \u2713
                                   id: 111
        =========================== Mismatch leaf subgraphs: ===========================
        ['01', '110']
        ============================= Mismatch node kinds: =============================
        {'aten::relu': 2}

    """
    if options is None:
        options = VerificationOptions()
    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET

    # From the aten graph, binary search over graph partitions to find the operator
    # whose export introduces the discrepancy.
    # TODO: Copied from utils.py `export` until `_optimize_graph`.

    # Tracing must see the requested train/eval mode. Remember the caller's mode so
    # the switch does not leak out of this function. `torch.onnx.export` also
    # restores the original mode. None means the mode was not changed here.
    previous_training: bool | None = None
    if training == torch.onnx.TrainingMode.TRAINING:
        previous_training = model.training
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        previous_training = model.training
        model.eval()

    try:
        with torch.no_grad():
            inputs_for_export = _prepare_input_for_export(input_args, {})
            args = utils._decide_input_format(model, inputs_for_export)

            # `_pre_trace_quant_model` may return a different object. Keep `model`
            # bound to the caller's module so its mode can be restored below.
            traced_model = utils._pre_trace_quant_model(model, args)
            graph, params, _torch_out, _module = utils._create_jit_graph(
                traced_model, args
            )
            params_dict = utils._get_named_param_dict(graph, params)

            utils._apply_friendly_debug_names(graph, params_dict)

            graph_info = GraphInfo(
                graph,
                input_args,
                params_dict,
                _experimental.ExportOptions(
                    do_constant_folding=do_constant_folding,
                    training=training,
                    opset_version=opset_version,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    verbose=verbose,
                ),
            )
            graph_info.find_mismatch(options)
            graph_info.pretty_print_mismatch()
            graph_info.pretty_print_tree()

            return graph_info
    finally:
        if previous_training is not None:
            model.train(previous_training)
1807