# mypy: allow-untyped-defs
"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.

ONNX Runtime is required, and is used as the ONNX backend for export verification.
"""

from __future__ import annotations

import contextlib
import copy
import dataclasses
import datetime
import difflib
import enum
import functools
import io
import itertools
import os
import tempfile
import warnings
from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union

import numpy as np

import torch
import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, _exporter_states, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number


_ORT_PROVIDERS = ("CPUExecutionProvider",)

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_InputKwargsType = Mapping[str, Any]
_OutputsType = Union[Sequence[_NumericType], Sequence]


class OnnxBackend(enum.Enum):
    """Enum class for ONNX backend used for export verification."""

    REFERENCE = "ONNXReferenceEvaluator"
    ONNX_RUNTIME_CPU = "CPUExecutionProvider"
    ONNX_RUNTIME_CUDA = "CUDAExecutionProvider"


@dataclasses.dataclass
class VerificationOptions:
    """Options for ONNX export verification.

    Attributes:
        flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
            Tensors for ONNX. Set this to False if nested structures are to be preserved
            for ONNX, which is usually the case with exporting ScriptModules. Defaults to True.
        ignore_none: Whether to ignore None type in torch output, which is usually the
            case with tracing. Set this to False if torch output should keep None type,
            which is usually the case with exporting ScriptModules. Defaults to True.
        check_shape: Whether to check that the shapes of PyTorch and ONNX Runtime outputs
            are exactly the same. Set this to False to allow output shape broadcasting.
            Defaults to True.
        check_dtype: Whether to check that the dtypes of PyTorch and ONNX Runtime outputs
            are consistent. Defaults to True.
        backend: ONNX backend for verification. Defaults to OnnxBackend.ONNX_RUNTIME_CPU.
        rtol: relative tolerance in comparison between ONNX and PyTorch outputs.
        atol: absolute tolerance in comparison between ONNX and PyTorch outputs.
        remained_onnx_input_idx: If provided, only the specified inputs will be passed
            to the ONNX model. Supply a list when there are unused inputs in the model.
            Since unused inputs will be removed in the exported ONNX model, supplying
            all inputs will cause an error on unexpected inputs. This parameter tells
            the verifier which inputs to pass into the ONNX model.
        acceptable_error_percentage: acceptable fraction of element mismatches in
            comparison. It should be a float between 0.0 and 1.0.
    """

    flatten: bool = True
    ignore_none: bool = True
    check_shape: bool = True
    check_dtype: bool = True
    backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU
    rtol: float = 1e-3
    atol: float = 1e-7
    remained_onnx_input_idx: Sequence[int] | None = None
    acceptable_error_percentage: float | None = None


def _flatten_tuples(elem):
    """Recursively flatten nested tuples in ``elem`` into a single list.

    Only ``tuple`` containers are flattened; lists, dicts and any other values are
    kept as single elements.
    """
    flattened = []
    for t in elem:
        if isinstance(t, tuple):
            flattened.extend(_flatten_tuples(t))
        else:
            flattened.append(t)
    return flattened


# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> list | np.ndarray:
    """Recursively convert ``elem`` into NumPy values.

    - ``torch.Tensor``: detached, moved to CPU and converted via ``Tensor.numpy()``.
    - ``list``/``tuple``: a list with every element converted.
    - ``bool``/``int``/``float``: a 0-d ``np.ndarray``.
    - ``dict``: a flat list ``[key0, value0, key1, value1, ...]`` of converted items.
    - anything else (e.g. ``None`` or an existing ``np.ndarray``): returned unchanged.
    """
    if isinstance(elem, torch.Tensor):
        # ``.numpy()`` refuses tensors that require grad; detaching a tensor that
        # does not require grad is harmless (it shares the same storage).
        return elem.detach().cpu().numpy()
    elif isinstance(elem, (list, tuple)):
        return [_to_numpy(inp) for inp in elem]
    elif isinstance(elem, (bool, int, float)):
        return np.array(elem)
    elif isinstance(elem, dict):
        flattened = []
        for k in elem:
            flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
        return flattened
    return elem


def _inline_flatten_list(inputs, res_list) -> list:
    """Recursively flatten nested lists/tuples in ``inputs`` into ``res_list``.

    ``res_list`` is mutated in place and also returned for convenience.
    """
    for item in inputs:
        if isinstance(item, (list, tuple)):
            _inline_flatten_list(item, res_list)
        else:
            res_list.append(item)
    return res_list


def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
    """Expand every value with ``utils.unpack_quantized_tensor`` and convert to NumPy.

    ``utils.unpack_quantized_tensor`` may yield several items per value
    (presumably the components of a quantized tensor -- see that helper); every
    resulting item is passed through :func:`_to_numpy`.
    """
    value_unpacked = []
    for value in values:
        value_unpacked.extend(
            utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted)
        )
    return [_to_numpy(v) for v in value_unpacked]


def _run_onnx(onnx_session, inputs) -> _OutputsType:
    """Run ``onnx_session`` on ``inputs`` and return its outputs.

    Args:
        onnx_session: an object exposing either ``get_inputs()`` (as
            ``onnxruntime.InferenceSession`` does) or ``input_names`` (as
            ``onnx.reference.ReferenceEvaluator`` does), plus ``run``.
        inputs: sequence of positional inputs. If its last element is a dict, that
            dict is taken as named inputs keyed by ONNX graph input name.

    Returns:
        Whatever ``onnx_session.run(None, feeds)`` returns.

    Raises:
        ValueError: if the session type is not recognized, or if a positional input
            has no remaining graph input slot (too many inputs, or the slot is
            already filled by a named input).
    """
    kw_inputs = {}
    if inputs and isinstance(inputs[-1], dict):
        kw_inputs = inputs[-1]
        inputs = inputs[:-1]
    # `_unpack_to_numpy` already runs `_to_numpy` on every element, so the
    # positional inputs need no further conversion.
    inputs = _unpack_to_numpy(_flatten_tuples(inputs))
    ort_inputs = {name: _to_numpy(value) for name, value in kw_inputs.items()}
    if hasattr(onnx_session, "get_inputs"):
        # onnxruntime.InferenceSession
        input_names = [i.name for i in onnx_session.get_inputs()]
    elif hasattr(onnx_session, "input_names"):
        # onnx.reference.ReferenceEvaluator
        input_names = onnx_session.input_names
    else:
        raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.")

    for i, value in enumerate(inputs):
        if i >= len(input_names) or input_names[i] in ort_inputs:
            raise ValueError(
                f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. "
                f"input names: {input_names}."
            )
        ort_inputs[input_names[i]] = value
    return onnx_session.run(None, ort_inputs)


def _ort_session(
    model: str | io.BytesIO, ort_providers: Sequence[str] | None = _ORT_PROVIDERS
):
    """Create an ``onnxruntime.InferenceSession`` for ``model``.

    Args:
        model: path to an ONNX file, or a ``BytesIO`` holding the serialized model.
        ort_providers: execution providers; ``None`` falls back to ``_ORT_PROVIDERS``.

    Raises:
        ImportError: if onnxruntime is not installed.
    """
    try:
        import onnxruntime  # type: ignore[import]
    except ImportError as e:
        raise ImportError("onnxruntime is required for export verification.") from e

    if ort_providers is None:
        ort_providers = _ORT_PROVIDERS

    session_options = onnxruntime.SessionOptions()
    # suppress ort warnings.
    # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
    session_options.log_severity_level = 3
    ort_session = onnxruntime.InferenceSession(
        model if isinstance(model, str) else model.getvalue(),
        session_options,
        providers=ort_providers,
    )
    return ort_session


def _onnx_reference_evaluator_session(model: str | io.BytesIO):
    """Create an ``onnx.reference.ReferenceEvaluator`` for ``model``.

    Raises:
        ImportError: if onnx (>= 1.13, which provides ``onnx.reference``) is missing.
    """
    try:
        import onnx
        from onnx import reference as onnx_reference  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc

    proto = (
        onnx.load(model)  # type: ignore[attr-defined]
        if isinstance(model, str)
        else onnx.load_model_from_string(model.getvalue())  # type: ignore[attr-defined]
    )
    onnx_session = onnx_reference.ReferenceEvaluator(proto)
    return onnx_session


def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend):
    """Create an inference session for ``model`` on the requested ``backend``.

    Raises:
        ValueError: if ``backend`` is not a supported :class:`OnnxBackend`.
    """
    if backend == OnnxBackend.REFERENCE:
        onnx_session = _onnx_reference_evaluator_session(model)
    elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}:
        # The enum value is the onnxruntime execution provider name.
        onnx_session = _ort_session(model, (backend.value,))
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    return onnx_session


def _compare_onnx_pytorch_outputs_in_np(
    onnx_outs: _OutputsType,
    pt_outs: _OutputsType,
    options: VerificationOptions,
):
    """Compare already-flattened ONNX and PyTorch outputs pairwise.

    Args:
        onnx_outs: flat sequence of ONNX backend outputs.
        pt_outs: flat sequence of PyTorch outputs converted to NumPy.
        options: options for verification.

    Raises:
        AssertionError: if the number of outputs differ, or an output pair is not
            close and its mismatch fraction exceeds
            ``options.acceptable_error_percentage`` (when set).
        ValueError: if ``options.acceptable_error_percentage`` is set but not within
            [0.0, 1.0].
    """
    # Raise explicitly rather than `assert` so the check survives `python -O`.
    if len(onnx_outs) != len(pt_outs):
        raise AssertionError(
            f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
        )
    acceptable_error_percentage = options.acceptable_error_percentage
    # Written as `not (0 <= x <= 1)` so NaN is rejected as well.
    if acceptable_error_percentage is not None and not (
        0.0 <= acceptable_error_percentage <= 1.0
    ):
        raise ValueError(
            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
        )

    for ort_out, pt_out in zip(onnx_outs, pt_outs):
        try:
            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
            if not options.check_shape:
                # Allow different but broadcastable output shapes.
                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
            torch.testing.assert_close(
                ort_out,
                pt_out,
                rtol=options.rtol,
                atol=options.atol,
                check_dtype=options.check_dtype,
                equal_nan=True,
            )
        except AssertionError as e:
            if acceptable_error_percentage:
                # `equal_nan=True` mirrors the `assert_close` call above; without it
                # positions where both outputs are NaN would count as mismatches.
                matched = np.isclose(
                    ort_out,
                    pt_out,
                    rtol=options.rtol,
                    atol=options.atol,
                    equal_nan=True,
                )
                error_percentage = 1 - np.sum(matched) / np.size(ort_out)
                if error_percentage <= acceptable_error_percentage:
                    warnings.warn(
                        f"Suppressed AssertionError:\n{e}.\n"
                        f"Error percentage {error_percentage} "
                        f"within acceptable range {acceptable_error_percentage}."
                    )
                    continue
            # `getattr` keeps these diagnostics from masking the original
            # AssertionError when an output is not an array (e.g. None).
            if getattr(ort_out, "dtype", None) in (np.uint8, np.int8):
                warnings.warn("ONNX output is quantized")
            if getattr(pt_out, "dtype", None) in (np.uint8, np.int8):
                warnings.warn("PyTorch output is quantized")
            raise


def _compare_onnx_pytorch_outputs(
    onnx_outs: _OutputsType,
    pt_outs: Any,
    options: VerificationOptions,
):
    """
    Compare ONNX and PyTorch outputs.

    Args:
        onnx_outs: outputs from ONNX backend.
        pt_outs: outputs from PyTorch.
        options: options for verification.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
    """
    if options.ignore_none:
        # torch.jit._flatten filters None type
        pt_outs, _ = torch.jit._flatten(pt_outs)
    else:
        pt_outs = _inline_flatten_list([pt_outs], [])
    pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
    onnx_outs = _inline_flatten_list(onnx_outs, [])
    _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)


def _prepare_input_for_pytorch(args, kwargs):
    """Prepare input for PyTorch model execution.

    Any future changes/formatting to the input before dispatching to the PyTorch
    model should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
    """
    if isinstance(args, (torch.Tensor, dict)):
        args = (args,)
    # In-place operators will update input tensor data as well.
    # Thus inputs are replicated before every forward call.
    args = copy.deepcopy(args)
    if kwargs:
        kwargs = copy.deepcopy(kwargs)
    else:
        kwargs = {}
    return args, kwargs


def _prepare_input_for_export(args, kwargs):
    """Prepare input for ONNX model export.

    Any future changes/formatting to the input before dispatching to the
    :func:`torch.onnx.export` api should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        onnx_inputs: positional arguments for ONNX model export, as `args` in
            :func:`torch.onnx.export`.
    """
    args, kwargs = _prepare_input_for_pytorch(args, kwargs)
    if not kwargs and len(args) > 0 and isinstance(args[-1], dict):
        # Append an empty kwargs dict so a trailing positional dict is not
        # mistaken for keyword arguments.
        onnx_inputs = args + ({},)
    elif kwargs:
        onnx_inputs = args + (kwargs,)
    else:
        onnx_inputs = args
    return onnx_inputs


def _prepare_input_for_onnx(
    args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool
):
    """Prepare input for ONNX model execution in ONNX backend.

    Any future changes/formatting to the input before dispatching to the ONNX backend
    run should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
        remained_onnx_input_idx: indices of inputs to be used for ONNX model execution.
        flatten: whether to flatten the input before dispatching to the ONNX model execution.

    Returns:
        onnx_inputs: positional arguments for ONNX model execution in ONNX backend.
    """
    onnx_inputs = _prepare_input_for_export(args, kwargs)
    if flatten:
        onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
    elif (
        onnx_inputs
        and isinstance(onnx_inputs[-1], dict)
        and not onnx_inputs[-1]
    ):
        # Handle empty kwargs (normally removed by flatten). The explicit type
        # check avoids comparing tensors/arrays against a dict.
        onnx_inputs = onnx_inputs[:-1]
    if remained_onnx_input_idx is not None:
        return [onnx_inputs[i] for i in remained_onnx_input_idx]
    else:
        return onnx_inputs


def _try_clone_model(model):
    """Used for preserving original model in case forward mutates model states."""
    try:
        return copy.deepcopy(model)
    except Exception:
        # Best effort: fall back to the original model, but tell the user.
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


def _compare_onnx_pytorch_model(
    pt_model: _ModelType,
    onnx_model_f: str | io.BytesIO,
    input_args: _InputArgsType,
    input_kwargs: _InputKwargsType | None,
    additional_test_inputs: Sequence[_InputArgsType] | None,
    options: VerificationOptions,
):
    """Compare outputs from ONNX model runs with outputs from PyTorch model runs.

    Args:
        pt_model: PyTorch model.
        onnx_model_f: ONNX model file path or file-like object.
        input_args: positional arguments for PyTorch model forward method.
        input_kwargs: keyword arguments for PyTorch model forward method.
        additional_test_inputs: additional positional arguments for PyTorch model
            forward method. Each entry is run with empty keyword arguments.
        options: options for verification.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """
    onnx_session = _onnx_backend_session(onnx_model_f, options.backend)

    def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
        pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
        # TODO: remove this and treat mutating model separately. See #77679
        pt_model_copy = _try_clone_model(pt_model)
        pt_outs = pt_model_copy(*pt_args, **pt_kwargs)

        onnx_inputs = _prepare_input_for_onnx(
            input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten
        )

        onnx_outs = _run_onnx(onnx_session, onnx_inputs)

        _compare_onnx_pytorch_outputs(
            onnx_outs=onnx_outs,
            pt_outs=pt_outs,
            options=options,
        )

    compare_onnx_pytorch_model_with_input(input_args, input_kwargs)

    if additional_test_inputs:
        for test_input_args in additional_test_inputs:
            compare_onnx_pytorch_model_with_input(test_input_args, {})
class _GraphDiff:
    """A class to represent the difference between two graphs.

    The graphs are only accessed through ``str(graph)``, ``graph.nodes()``,
    ``str(node)`` and ``node.sourceRange()``.
    """

    def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
        """Construct a _GraphDiff object.

        Args:
            graph_a (_C.Graph): First graph to compare.
            graph_b (_C.Graph): Second graph to compare.
        """
        self.graph_a = graph_a
        self.graph_b = graph_b

    def __str__(self):
        """See function :func:`diff_report`."""
        return self.diff_report()

    def _indent(self, lines: str) -> str:
        """Prefix every line of ``lines`` with a tab and re-join them with newlines."""
        return "\n".join(["\t" + line for line in lines.splitlines()])

    def _ndiff(self, text_a: str, text_b: str) -> str:
        """Return the ``difflib.ndiff`` of two multi-line strings as one string.

        Every line is newline-terminated before diffing. ``ndiff`` echoes input
        lines verbatim, so an unterminated final line (for example ``"None"``, the
        ``str`` of a ``zip_longest`` padding value) would otherwise be glued onto
        the following diff entry, producing output such as ``"- None+ ..."``.
        """
        lines_a = [line + "\n" for line in text_a.splitlines()]
        lines_b = [line + "\n" for line in text_b.splitlines()]
        return "".join(difflib.ndiff(lines_a, lines_b))

    def diff_report(self) -> str:
        """Return a string representation of the graph difference.

        The report shows the first pair of nodes that diverges. It also shows the source
        location of the pair of nodes.

        Returns:
            graph_diff_report (str): A string representation of the graph difference,
                or an empty string when both graphs print identically.
        """
        graph_a = self.graph_a
        graph_b = self.graph_b

        graph_a_str = str(graph_a)
        graph_b_str = str(graph_b)

        if graph_a_str == graph_b_str:
            return ""

        graph_diff_report = [
            "Graph diff:",
            self._indent(self._ndiff(graph_a_str, graph_b_str)),
        ]

        # zip_longest pads the shorter node sequence with None, so graphs with
        # different node counts still produce a diverging pair.
        for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()):
            if str(node_a) == str(node_b):
                continue

            graph_diff_report.append("First diverging operator:")
            graph_diff_report.extend(
                ["node diff:", self._indent(self._ndiff(str(node_a), str(node_b)))]
            )

            # A padded (None) side has no source location to report.
            stack_a = node_a.sourceRange() if node_a is not None else None
            if stack_a:
                graph_diff_report.extend(
                    ["Former source location:", self._indent(str(stack_a))]
                )
            stack_b = node_b.sourceRange() if node_b is not None else None
            if stack_b:
                graph_diff_report.extend(
                    ["Latter source location:", self._indent(str(stack_b))]
                )

            # Only the first diverging pair is reported.
            break

        return "\n".join(graph_diff_report)


def _check_graph_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions,
    model_to_graph_func: Callable[
        [
            torch.nn.Module,
            tuple[Any, ...],
            Mapping[str, Any],
            _experimental.ExportOptions,
        ],
        _C.Graph,
    ],
) -> str:
    """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`.

    The graph built from the first input group is the reference. The graph of each
    later group is compared against it, and the first non-empty diff is returned
    immediately; groups after that are not converted.

    Args:
        model: See :func:`check_export_model_diff`.
        test_input_groups: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.
        model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph.

    Returns:
        graph_diff_report (str): A string representation of the graph difference, or
            an empty string if every graph matches the reference.

    Raises:
        ValueError: if fewer than two input groups are given.
    """
    if len(test_input_groups) < 2:
        raise ValueError("Need at least two groups of test inputs to compare.")

    ref_jit_graph = None
    for args, kwargs in test_input_groups:
        jit_graph = model_to_graph_func(model, args, kwargs, export_options)
        if ref_jit_graph is None:
            ref_jit_graph = jit_graph
            continue

        graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report()
        if graph_diff_report:
            return graph_diff_report
    return ""
def _traced_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model.

    Only ``export_options.training`` and ``export_options.verbose`` are read here.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        jit_graph (_C.Graph): A traced JIT graph.
    """
    training = export_options.training
    verbose = export_options.verbose

    with utils.exporter_context(model, training, verbose):
        export_inputs = _prepare_input_for_export(args, kwargs)
        # `model` may be replaced by the helper's result before tracing
        # (presumably quantization-specific preprocessing, per its name -- confirm).
        model = utils._pre_trace_quant_model(model, export_inputs)
        jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs)
        return jit_graph


def _onnx_graph_from_model(
    model: torch.nn.Module | torch.jit.ScriptModule,
    args: tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Side effects: sets ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type``. Neither is restored before returning.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET

    # NOTE(review): no matching reset of the trace module map happens in this
    # function; confirm whether torch.onnx.utils expects one after export.
    utils._setup_trace_module_map(model, export_modules_as_functions)

    # Fall back to plain ONNX operator export when no type was requested.
    if not operator_export_type:
        operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    # Process-wide state read by the export machinery; not restored afterwards.
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph


def _onnx_graph_from_aten_graph(
    graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any] | None = None,
) -> tuple[torch.Graph, dict[str, Any]]:
    """Convert an ATen-level TorchScript graph into an ONNX graph.

    All passes run on ``graph.copy()``; the ``graph`` object passed in is not the
    one handed to the passes.

    Side effects: sets ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type``.

    Args:
        graph: TorchScript graph to convert.
        export_options: Export settings (opset, operator export type, dynamic axes,
            input names, training mode, constant folding, verbosity).
        params_dict: Parameter values keyed by name; defaults to an empty dict.
            (The caller in this file keys them by graph input ``debugName()``.)

    Returns:
        A tuple of the converted ONNX graph and the parameter dict as updated by
        the passes.
    """
    if params_dict is None:
        params_dict = {}
    # NOTE(review): unlike _onnx_graph_from_model, there is no fallback to
    # OperatorExportTypes.ONNX here; confirm ExportOptions always supplies a value.
    operator_export_type = export_options.operator_export_type
    dynamic_axes = export_options.dynamic_axes or {}
    input_names = export_options.input_names
    training = export_options.training
    do_constant_folding = export_options.do_constant_folding
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    do_constant_folding = utils._decide_constant_folding(
        do_constant_folding, operator_export_type, training
    )

    # TODO: Below is doing aten graph to onnx. It should be abstracted as a
    # function in torch/onnx/utils.py.
    graph = graph.copy()
    graph = utils._optimize_graph(
        graph,
        operator_export_type,
        params_dict=params_dict,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
    )

    # The eval peephole pass only runs when not exporting in training mode.
    if training is None or training == _C_onnx.TrainingMode.EVAL:
        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

    # Constant folding is gated both by the resolved flag and by a minimum opset.
    if (
        do_constant_folding
        and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if export_options.verbose:
        print("ONNX graph: ", graph)

    return graph, params_dict


def _onnx_proto_from_onnx_graph(
    onnx_graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: dict[str, Any],
) -> tuple[bytes, Mapping[str, bytes]]:
    """Serialize an ONNX JIT graph with ``Graph._export_onnx``.

    Args:
        onnx_graph: ONNX graph, e.g. as produced by ``_onnx_graph_from_aten_graph``.
        export_options: Export settings (opset, dynamic axes, operator export type,
            initializer handling, custom opsets, verbosity).
        params_dict: Parameter values passed to the exporter.

    Returns:
        The ``(proto, export_map)`` pair from ``_export_onnx``; its two remaining
        return values are discarded.
    """
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
    dynamic_axes = export_options.dynamic_axes or {}
    operator_export_type = export_options.operator_export_type
    val_keep_init_as_ip = utils._decide_keep_init_as_input(
        export_options.keep_initializers_as_inputs,
        operator_export_type,
        opset_version,
    )
    val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
    custom_opsets = export_options.custom_opsets or {}

    # NOTE(review): arguments are positional. The literals False / "" / {} presumably
    # mean "no deferred weight export", "no external-data file path" and "no
    # node-attribute names", and `not verbose` presumably strips doc strings.
    # Confirm against the call in torch.onnx.utils._export.
    proto, export_map, _, _ = onnx_graph._export_onnx(  # type: ignore[attr-defined]
        params_dict,
        opset_version,
        dynamic_axes,
        False,
        operator_export_type,
        not export_options.verbose,
        val_keep_init_as_ip,
        custom_opsets,
        val_add_node_names,
        "",
        {},
    )

    return proto, export_map


def check_export_model_diff(
    model: torch.nn.Module | torch.jit.ScriptModule,
    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions | None = None,
) -> str:
    """Verify exported model discrepancy between different groups of inputs.

    A graph is exported for each group of inputs. The exported graphs are then
    compared to each other, and the first pair of diverging nodes is reported. This
    function first checks the jit graph. If no discrepancies were found, it then
    checks the onnx graph.

    Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless
    of the inputs used for exporting. A discrepancy implies the graph exported is
    not accurate when run on other groups of inputs, which will typically result in
    runtime errors or mismatching output.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
        test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
            of input groups to be used to export the model. Each input group is a pair of
            (args, kwargs). At least two groups are required.
        export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
            object that controls the export behavior. A default-constructed one is
            used when omitted.

    Returns:
        str: A string containing the diff of the exported models, or an empty
            string if no discrepancy was found.
    """
    export_options = (
        _experimental.ExportOptions() if export_options is None else export_options
    )

    # Compare traced JIT graphs first; ONNX graphs are compared only if those agree.
    jit_diff_report = _check_graph_diff(
        model, test_input_groups, export_options, _traced_graph_from_model
    )
    if jit_diff_report:
        return jit_diff_report

    return _check_graph_diff(
        model, test_input_groups, export_options, _onnx_graph_from_model
    )


def verify(
    model: _ModelType,
    input_args: _InputArgsType,
    input_kwargs: _InputKwargsType | None = None,
    do_constant_folding: bool = True,
    dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]]
    | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Sequence[_InputArgsType] | None = None,
    options: VerificationOptions | None = None,
):
    """Verify model export to ONNX against original PyTorch model.

    Note: ``model`` itself is switched to train or eval mode according to
    ``training`` before exporting.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
        input_args (torch.Tensor or tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
        use_external_data (bool, optional): Explicitly specify whether to export the
            model with external data.
        additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
            input arguments to test. Currently only *args are supported.
        options (VerificationOptions, optional): A VerificationOptions object that
            controls the verification behavior.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
    """
    if options is None:
        options = VerificationOptions()

    # Mutates the caller's model: its train/eval mode follows `training`.
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()
    with torch.no_grad(), contextlib.ExitStack() as stack:
        model_f: str | io.BytesIO = io.BytesIO()
        if use_external_data:
            # The temporary directory is registered on the ExitStack, so it stays
            # alive until the comparison below has finished.
            tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
            model_f = os.path.join(tmpdir_path, "model.onnx")

        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)

        # TODO(#77679): remove this and treat mutating model separately.
        # The clone taken before export is the model compared against ONNX below.
        model_copy = _try_clone_model(model)
        utils._export(
            model,
            inputs_for_export,
            model_f,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            fixed_batch_size=fixed_batch_size,
            training=training,
            verbose=verbose,
        )

        _compare_onnx_pytorch_model(
            pt_model=model_copy,
            onnx_model_f=model_f,
            input_args=input_args,
            input_kwargs=input_kwargs,
            additional_test_inputs=additional_test_inputs,
            options=options,
        )


def verify_aten_graph(
    graph: torch.Graph,
    input_args: tuple[Any, ...],
    export_options:
_experimental.ExportOptions, 863*da0073e9SAndroid Build Coastguard Worker params_dict: dict[str, Any] | None = None, 864*da0073e9SAndroid Build Coastguard Worker verification_options: VerificationOptions | None = None, 865*da0073e9SAndroid Build Coastguard Worker) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: 866*da0073e9SAndroid Build Coastguard Worker if verification_options is None: 867*da0073e9SAndroid Build Coastguard Worker verification_options = VerificationOptions() 868*da0073e9SAndroid Build Coastguard Worker if params_dict is None: 869*da0073e9SAndroid Build Coastguard Worker params_dict = {} 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker original_jit_graph = graph 872*da0073e9SAndroid Build Coastguard Worker graph = graph.copy() 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker # Execute aten graph and get reference torch jit outputs. 875*da0073e9SAndroid Build Coastguard Worker graph_inputs = list(graph.inputs()) 876*da0073e9SAndroid Build Coastguard Worker jit_inputs = tuple([arg for arg in input_args if arg is not None]) 877*da0073e9SAndroid Build Coastguard Worker weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] 878*da0073e9SAndroid Build Coastguard Worker assert all(w is not None for w in weights) 879*da0073e9SAndroid Build Coastguard Worker # TODO: Only copy the argument if mutation is detected in Graph. 
880*da0073e9SAndroid Build Coastguard Worker jit_inputs = copy.deepcopy(jit_inputs) 881*da0073e9SAndroid Build Coastguard Worker jit_input_and_parameters = jit_inputs + tuple(weights) 882*da0073e9SAndroid Build Coastguard Worker jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] 883*da0073e9SAndroid Build Coastguard Worker if not isinstance(jit_outs, (list, tuple)): 884*da0073e9SAndroid Build Coastguard Worker jit_outs = [jit_outs] 885*da0073e9SAndroid Build Coastguard Worker 886*da0073e9SAndroid Build Coastguard Worker # Convert aten graph to onnx graph. 887*da0073e9SAndroid Build Coastguard Worker graph, onnx_params_dict = _onnx_graph_from_aten_graph( 888*da0073e9SAndroid Build Coastguard Worker graph, export_options, params_dict 889*da0073e9SAndroid Build Coastguard Worker ) 890*da0073e9SAndroid Build Coastguard Worker 891*da0073e9SAndroid Build Coastguard Worker proto, export_map = _onnx_proto_from_onnx_graph( 892*da0073e9SAndroid Build Coastguard Worker graph, export_options, onnx_params_dict 893*da0073e9SAndroid Build Coastguard Worker ) 894*da0073e9SAndroid Build Coastguard Worker model_f: str | io.BytesIO = io.BytesIO() 895*da0073e9SAndroid Build Coastguard Worker export_type = _exporter_states.ExportTypes.PROTOBUF_FILE 896*da0073e9SAndroid Build Coastguard Worker onnx_proto_utils._export_file(proto, model_f, export_type, export_map) 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker # NOTE: Verification is unstable. Try catch to emit information for debugging. 899*da0073e9SAndroid Build Coastguard Worker try: 900*da0073e9SAndroid Build Coastguard Worker # NOTE: Input might be dce'ed, so we need to remove those from the input args. 
901*da0073e9SAndroid Build Coastguard Worker new_input_names = {v.debugName() for v in graph.inputs()} 902*da0073e9SAndroid Build Coastguard Worker new_input_args = [] 903*da0073e9SAndroid Build Coastguard Worker for v, arg in zip(original_jit_graph.inputs(), input_args): 904*da0073e9SAndroid Build Coastguard Worker if v.debugName() in new_input_names: 905*da0073e9SAndroid Build Coastguard Worker new_input_args.append(arg) 906*da0073e9SAndroid Build Coastguard Worker input_args = tuple(new_input_args) 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker onnx_inputs = _prepare_input_for_onnx( 909*da0073e9SAndroid Build Coastguard Worker input_args, 910*da0073e9SAndroid Build Coastguard Worker {}, 911*da0073e9SAndroid Build Coastguard Worker verification_options.remained_onnx_input_idx, 912*da0073e9SAndroid Build Coastguard Worker verification_options.flatten, 913*da0073e9SAndroid Build Coastguard Worker ) 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker onnx_session = _onnx_backend_session(model_f, verification_options.backend) 916*da0073e9SAndroid Build Coastguard Worker onnx_outs = _run_onnx(onnx_session, onnx_inputs) 917*da0073e9SAndroid Build Coastguard Worker del onnx_session # To free device memory 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker try: 920*da0073e9SAndroid Build Coastguard Worker _compare_onnx_pytorch_outputs( 921*da0073e9SAndroid Build Coastguard Worker onnx_outs=onnx_outs, 922*da0073e9SAndroid Build Coastguard Worker pt_outs=jit_outs, 923*da0073e9SAndroid Build Coastguard Worker options=verification_options, 924*da0073e9SAndroid Build Coastguard Worker ) 925*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 926*da0073e9SAndroid Build Coastguard Worker return e, graph, jit_outs, onnx_outs 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker return None, graph, jit_outs, 
onnx_outs 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker except Exception as e: 931*da0073e9SAndroid Build Coastguard Worker print("Unexpected error during verification.") 932*da0073e9SAndroid Build Coastguard Worker print("jit graph: ", original_jit_graph) 933*da0073e9SAndroid Build Coastguard Worker print("onnx graph: ", graph) 934*da0073e9SAndroid Build Coastguard Worker raise e 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Workerclass GraphInfoPrettyPrinter: 938*da0073e9SAndroid Build Coastguard Worker graph_info: GraphInfo | None 939*da0073e9SAndroid Build Coastguard Worker upper_printer: GraphInfoPrettyPrinter | None 940*da0073e9SAndroid Build Coastguard Worker lower_printer: GraphInfoPrettyPrinter | None 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker graph_str_lambdas: Mapping[int, str] 943*da0073e9SAndroid Build Coastguard Worker connector_str_lambdas: Mapping[int, str] 944*da0073e9SAndroid Build Coastguard Worker children_str_lambdas: Mapping[int, str] 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph_info: GraphInfo | None): 947*da0073e9SAndroid Build Coastguard Worker self.graph_info = graph_info 948*da0073e9SAndroid Build Coastguard Worker if ( 949*da0073e9SAndroid Build Coastguard Worker graph_info is not None 950*da0073e9SAndroid Build Coastguard Worker and graph_info.upper_graph_info is not None 951*da0073e9SAndroid Build Coastguard Worker and graph_info.lower_graph_info is not None 952*da0073e9SAndroid Build Coastguard Worker ): 953*da0073e9SAndroid Build Coastguard Worker self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) 954*da0073e9SAndroid Build Coastguard Worker self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) 955*da0073e9SAndroid Build Coastguard Worker else: 
956*da0073e9SAndroid Build Coastguard Worker self.upper_printer = None 957*da0073e9SAndroid Build Coastguard Worker self.lower_printer = None 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker def _total_rows(self) -> int: 960*da0073e9SAndroid Build Coastguard Worker if self.graph_info is None: 961*da0073e9SAndroid Build Coastguard Worker return 1 962*da0073e9SAndroid Build Coastguard Worker if self.upper_printer and self.lower_printer: 963*da0073e9SAndroid Build Coastguard Worker return ( 964*da0073e9SAndroid Build Coastguard Worker self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 965*da0073e9SAndroid Build Coastguard Worker ) 966*da0073e9SAndroid Build Coastguard Worker return 2 # Two lines: node count + id. 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker def _node_count_segment_str(self) -> str: 969*da0073e9SAndroid Build Coastguard Worker if self.graph_info is None: 970*da0073e9SAndroid Build Coastguard Worker return "..." 
971*da0073e9SAndroid Build Coastguard Worker node_count = self.graph_info.essential_node_count() 972*da0073e9SAndroid Build Coastguard Worker has_mismatch = self.graph_info.has_mismatch() 973*da0073e9SAndroid Build Coastguard Worker error_node_kind = ( 974*da0073e9SAndroid Build Coastguard Worker f"({self.graph_info.essential_node_kinds().pop()})" 975*da0073e9SAndroid Build Coastguard Worker if node_count == 1 and has_mismatch 976*da0073e9SAndroid Build Coastguard Worker else "" 977*da0073e9SAndroid Build Coastguard Worker ) 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker def _graph_id_segment_str(self) -> str: 982*da0073e9SAndroid Build Coastguard Worker if self.graph_info is None: 983*da0073e9SAndroid Build Coastguard Worker return "" 984*da0073e9SAndroid Build Coastguard Worker return f"id: {self.graph_info.id}" 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker def _max_segment_columns(self) -> int: 987*da0073e9SAndroid Build Coastguard Worker return max( 988*da0073e9SAndroid Build Coastguard Worker map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) 989*da0073e9SAndroid Build Coastguard Worker ) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker def _graph_segment_str_at_line(self, line: int) -> str: 992*da0073e9SAndroid Build Coastguard Worker """Get the string representation of the graph segment at the given line.""" 993*da0073e9SAndroid Build Coastguard Worker if line == 0: 994*da0073e9SAndroid Build Coastguard Worker result_str = self._node_count_segment_str() 995*da0073e9SAndroid Build Coastguard Worker result_str += " " * (self._max_segment_columns() - len(result_str)) 996*da0073e9SAndroid Build Coastguard Worker return result_str 
997*da0073e9SAndroid Build Coastguard Worker if line == 1: 998*da0073e9SAndroid Build Coastguard Worker result_str = self._graph_id_segment_str() 999*da0073e9SAndroid Build Coastguard Worker result_str += " " * (self._max_segment_columns() - len(result_str)) 1000*da0073e9SAndroid Build Coastguard Worker return result_str 1001*da0073e9SAndroid Build Coastguard Worker if 0 <= line < self._total_rows(): 1002*da0073e9SAndroid Build Coastguard Worker return " " * self._max_segment_columns() 1003*da0073e9SAndroid Build Coastguard Worker return "" 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard Worker def _connector_segment_str_at_line(self, line: int) -> str: 1006*da0073e9SAndroid Build Coastguard Worker """Get the connector segment string at the given line.""" 1007*da0073e9SAndroid Build Coastguard Worker if self.upper_printer is None and self.lower_printer is None: 1008*da0073e9SAndroid Build Coastguard Worker return "" 1009*da0073e9SAndroid Build Coastguard Worker upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 1010*da0073e9SAndroid Build Coastguard Worker lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 1011*da0073e9SAndroid Build Coastguard Worker if line == 0: 1012*da0073e9SAndroid Build Coastguard Worker return " __" 1013*da0073e9SAndroid Build Coastguard Worker elif line < upper_total_rows + 1: 1014*da0073e9SAndroid Build Coastguard Worker return " | " 1015*da0073e9SAndroid Build Coastguard Worker elif line == upper_total_rows + 1: 1016*da0073e9SAndroid Build Coastguard Worker return " |__" 1017*da0073e9SAndroid Build Coastguard Worker elif line < upper_total_rows + lower_total_rows + 1: 1018*da0073e9SAndroid Build Coastguard Worker return " " 1019*da0073e9SAndroid Build Coastguard Worker return "" 1020*da0073e9SAndroid Build Coastguard Worker 1021*da0073e9SAndroid Build Coastguard Worker def _children_str_at_line(self, line: int) -> str: 1022*da0073e9SAndroid 
Build Coastguard Worker """Get the string representation of the children at the given line. 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker Recursively calls `_str_at_line` on children nodes. 1025*da0073e9SAndroid Build Coastguard Worker """ 1026*da0073e9SAndroid Build Coastguard Worker if self.upper_printer is None and self.lower_printer is None: 1027*da0073e9SAndroid Build Coastguard Worker return "" 1028*da0073e9SAndroid Build Coastguard Worker upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 1029*da0073e9SAndroid Build Coastguard Worker lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 1030*da0073e9SAndroid Build Coastguard Worker if 0 <= line < upper_total_rows: 1031*da0073e9SAndroid Build Coastguard Worker return ( 1032*da0073e9SAndroid Build Coastguard Worker self.upper_printer._str_at_line(line) if self.upper_printer else "..." 1033*da0073e9SAndroid Build Coastguard Worker ) 1034*da0073e9SAndroid Build Coastguard Worker elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: 1035*da0073e9SAndroid Build Coastguard Worker return ( 1036*da0073e9SAndroid Build Coastguard Worker self.lower_printer._str_at_line(line - upper_total_rows - 1) 1037*da0073e9SAndroid Build Coastguard Worker if self.lower_printer 1038*da0073e9SAndroid Build Coastguard Worker else "..." 
1039*da0073e9SAndroid Build Coastguard Worker ) 1040*da0073e9SAndroid Build Coastguard Worker return "" 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker def _str_at_line(self, line: int) -> str: 1043*da0073e9SAndroid Build Coastguard Worker """Get the string representation of the graph at the given line.""" 1044*da0073e9SAndroid Build Coastguard Worker return ( 1045*da0073e9SAndroid Build Coastguard Worker self._graph_segment_str_at_line(line) 1046*da0073e9SAndroid Build Coastguard Worker + self._connector_segment_str_at_line(line) 1047*da0073e9SAndroid Build Coastguard Worker + self._children_str_at_line(line) 1048*da0073e9SAndroid Build Coastguard Worker ) 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker def pretty_print(self): 1051*da0073e9SAndroid Build Coastguard Worker if self.graph_info is None: 1052*da0073e9SAndroid Build Coastguard Worker print(None) 1053*da0073e9SAndroid Build Coastguard Worker return 1054*da0073e9SAndroid Build Coastguard Worker # Print tree. 1055*da0073e9SAndroid Build Coastguard Worker print(" Tree: ".center(80, "=")) 1056*da0073e9SAndroid Build Coastguard Worker total_rows = self._total_rows() 1057*da0073e9SAndroid Build Coastguard Worker for line in range(total_rows): 1058*da0073e9SAndroid Build Coastguard Worker print(self._str_at_line(line).rstrip()) 1059*da0073e9SAndroid Build Coastguard Worker if self.graph_info.has_mismatch(): 1060*da0073e9SAndroid Build Coastguard Worker # Summarize leaf subgraphs with mismatch. 
1061*da0073e9SAndroid Build Coastguard Worker print(" Mismatch leaf subgraphs: ".center(80, "=")) 1062*da0073e9SAndroid Build Coastguard Worker print( 1063*da0073e9SAndroid Build Coastguard Worker [ 1064*da0073e9SAndroid Build Coastguard Worker graph_info.id 1065*da0073e9SAndroid Build Coastguard Worker for graph_info in self.graph_info.all_mismatch_leaf_graph_info() 1066*da0073e9SAndroid Build Coastguard Worker ] 1067*da0073e9SAndroid Build Coastguard Worker ) 1068*da0073e9SAndroid Build Coastguard Worker # Summarize node kinds with mismatch. 1069*da0073e9SAndroid Build Coastguard Worker mismatch_node_kinds: dict[str, int] = {} 1070*da0073e9SAndroid Build Coastguard Worker for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): 1071*da0073e9SAndroid Build Coastguard Worker node_kinds = graph_info.essential_node_kinds() 1072*da0073e9SAndroid Build Coastguard Worker if len(node_kinds) == 1: 1073*da0073e9SAndroid Build Coastguard Worker node_kind = node_kinds.pop() 1074*da0073e9SAndroid Build Coastguard Worker mismatch_node_kinds[node_kind] = ( 1075*da0073e9SAndroid Build Coastguard Worker mismatch_node_kinds.get(node_kind, 0) + 1 1076*da0073e9SAndroid Build Coastguard Worker ) 1077*da0073e9SAndroid Build Coastguard Worker print(" Mismatch node kinds: ".center(80, "=")) 1078*da0073e9SAndroid Build Coastguard Worker print(mismatch_node_kinds) 1079*da0073e9SAndroid Build Coastguard Worker else: 1080*da0073e9SAndroid Build Coastguard Worker print(" No mismatch found. 
".center(80, "=")) 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Workerclass OnnxTestCaseRepro: 1084*da0073e9SAndroid Build Coastguard Worker def __init__(self, repro_dir): 1085*da0073e9SAndroid Build Coastguard Worker self.repro_dir = repro_dir 1086*da0073e9SAndroid Build Coastguard Worker self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( 1087*da0073e9SAndroid Build Coastguard Worker repro_dir 1088*da0073e9SAndroid Build Coastguard Worker ) 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker @classmethod 1091*da0073e9SAndroid Build Coastguard Worker def create_test_case_repro( 1092*da0073e9SAndroid Build Coastguard Worker cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None 1093*da0073e9SAndroid Build Coastguard Worker ): 1094*da0073e9SAndroid Build Coastguard Worker """Create a repro under "{dir}/test_{name}" for an ONNX test case. 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker The test case contains the model and the inputs/outputs data. 
The directory 1097*da0073e9SAndroid Build Coastguard Worker structure is as follows: 1098*da0073e9SAndroid Build Coastguard Worker 1099*da0073e9SAndroid Build Coastguard Worker dir 1100*da0073e9SAndroid Build Coastguard Worker \u251c\u2500\u2500 test_<name> 1101*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 model.onnx 1102*da0073e9SAndroid Build Coastguard Worker \u2502 \u2514\u2500\u2500 test_data_set_0 1103*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 input_0.pb 1104*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 input_1.pb 1105*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 output_0.pb 1106*da0073e9SAndroid Build Coastguard Worker \u2502 \u2514\u2500\u2500 output_1.pb 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker Args: 1109*da0073e9SAndroid Build Coastguard Worker proto: ONNX model proto. 1110*da0073e9SAndroid Build Coastguard Worker inputs: Inputs to the model. 1111*da0073e9SAndroid Build Coastguard Worker outputs: Outputs of the model. 1112*da0073e9SAndroid Build Coastguard Worker dir: Directory to save the repro. 1113*da0073e9SAndroid Build Coastguard Worker name: Name of the test case. If not specified, a name based on current time 1114*da0073e9SAndroid Build Coastguard Worker will be generated. 1115*da0073e9SAndroid Build Coastguard Worker Returns: 1116*da0073e9SAndroid Build Coastguard Worker Path to the repro. 
1117*da0073e9SAndroid Build Coastguard Worker """ 1118*da0073e9SAndroid Build Coastguard Worker if name is None: 1119*da0073e9SAndroid Build Coastguard Worker name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 1120*da0073e9SAndroid Build Coastguard Worker return onnx_proto_utils.export_as_test_case( 1121*da0073e9SAndroid Build Coastguard Worker proto, 1122*da0073e9SAndroid Build Coastguard Worker _to_numpy(inputs), 1123*da0073e9SAndroid Build Coastguard Worker _to_numpy(outputs), 1124*da0073e9SAndroid Build Coastguard Worker name, 1125*da0073e9SAndroid Build Coastguard Worker dir, 1126*da0073e9SAndroid Build Coastguard Worker ) 1127*da0073e9SAndroid Build Coastguard Worker 1128*da0073e9SAndroid Build Coastguard Worker def validate(self, options: VerificationOptions): 1129*da0073e9SAndroid Build Coastguard Worker """Run the ONNX test case with options.backend, and compare with the expected outputs. 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker Args: 1132*da0073e9SAndroid Build Coastguard Worker options: Options for validation. 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker Raise: 1135*da0073e9SAndroid Build Coastguard Worker AssertionError: if outputs from options.backend and expected outputs are not 1136*da0073e9SAndroid Build Coastguard Worker equal up to specified precision. 
1137*da0073e9SAndroid Build Coastguard Worker """ 1138*da0073e9SAndroid Build Coastguard Worker onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) 1139*da0073e9SAndroid Build Coastguard Worker run_outputs = onnx_session.run(None, self.inputs) 1140*da0073e9SAndroid Build Coastguard Worker if hasattr(onnx_session, "get_outputs"): 1141*da0073e9SAndroid Build Coastguard Worker output_names = [o.name for o in onnx_session.get_outputs()] 1142*da0073e9SAndroid Build Coastguard Worker elif hasattr(onnx_session, "output_names"): 1143*da0073e9SAndroid Build Coastguard Worker output_names = onnx_session.output_names 1144*da0073e9SAndroid Build Coastguard Worker else: 1145*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") 1146*da0073e9SAndroid Build Coastguard Worker expected_outs = [self.outputs[name] for name in output_names] 1147*da0073e9SAndroid Build Coastguard Worker _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass 1151*da0073e9SAndroid Build Coastguard Workerclass GraphInfo: 1152*da0073e9SAndroid Build Coastguard Worker """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.""" 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker graph: torch.Graph 1155*da0073e9SAndroid Build Coastguard Worker input_args: tuple[Any, ...] 
1156*da0073e9SAndroid Build Coastguard Worker params_dict: dict[str, Any] 1157*da0073e9SAndroid Build Coastguard Worker export_options: _experimental.ExportOptions = dataclasses.field( 1158*da0073e9SAndroid Build Coastguard Worker default_factory=_experimental.ExportOptions 1159*da0073e9SAndroid Build Coastguard Worker ) 1160*da0073e9SAndroid Build Coastguard Worker mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) 1161*da0073e9SAndroid Build Coastguard Worker pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) 1162*da0073e9SAndroid Build Coastguard Worker upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) 1163*da0073e9SAndroid Build Coastguard Worker lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) 1164*da0073e9SAndroid Build Coastguard Worker id: str = dataclasses.field(default="") 1165*da0073e9SAndroid Build Coastguard Worker _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) 1166*da0073e9SAndroid Build Coastguard Worker 1167*da0073e9SAndroid Build Coastguard Worker _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( 1168*da0073e9SAndroid Build Coastguard Worker {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} 1169*da0073e9SAndroid Build Coastguard Worker ) 1170*da0073e9SAndroid Build Coastguard Worker 1171*da0073e9SAndroid Build Coastguard Worker def clear(self): 1172*da0073e9SAndroid Build Coastguard Worker """Clear states and results of previous verification.""" 1173*da0073e9SAndroid Build Coastguard Worker self.mismatch_error = None 1174*da0073e9SAndroid Build Coastguard Worker self.pt_outs = None 1175*da0073e9SAndroid Build Coastguard Worker self._onnx_graph = None 1176*da0073e9SAndroid Build Coastguard Worker self.upper_graph_info = None 1177*da0073e9SAndroid Build Coastguard Worker self.lower_graph_info = None 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build 
Coastguard Worker def pretty_print_tree(self): 1180*da0073e9SAndroid Build Coastguard Worker """Pretty print `GraphInfo` tree. 1181*da0073e9SAndroid Build Coastguard Worker 1182*da0073e9SAndroid Build Coastguard Worker Each node represents a subgraph, showing the number of nodes in the subgraph and 1183*da0073e9SAndroid Build Coastguard Worker a check mark if the subgraph has output mismatch between torch and ONNX. 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker The id of the subgraph is shown under the node. The `GraphInfo` object for any 1186*da0073e9SAndroid Build Coastguard Worker subgraph can be retrieved by calling `graph_info.find_partition(id)`. 1187*da0073e9SAndroid Build Coastguard Worker 1188*da0073e9SAndroid Build Coastguard Worker Example:: 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker ==================================== Tree: ===================================== 1191*da0073e9SAndroid Build Coastguard Worker 5 X __2 X __1 \u2713 1192*da0073e9SAndroid Build Coastguard Worker id: | id: 0 | id: 00 1193*da0073e9SAndroid Build Coastguard Worker | | 1194*da0073e9SAndroid Build Coastguard Worker | |__1 X (aten::relu) 1195*da0073e9SAndroid Build Coastguard Worker | id: 01 1196*da0073e9SAndroid Build Coastguard Worker | 1197*da0073e9SAndroid Build Coastguard Worker |__3 X __1 \u2713 1198*da0073e9SAndroid Build Coastguard Worker id: 1 | id: 10 1199*da0073e9SAndroid Build Coastguard Worker | 1200*da0073e9SAndroid Build Coastguard Worker |__2 X __1 X (aten::relu) 1201*da0073e9SAndroid Build Coastguard Worker id: 11 | id: 110 1202*da0073e9SAndroid Build Coastguard Worker | 1203*da0073e9SAndroid Build Coastguard Worker |__1 \u2713 1204*da0073e9SAndroid Build Coastguard Worker id: 111 1205*da0073e9SAndroid Build Coastguard Worker =========================== Mismatch leaf subgraphs: =========================== 1206*da0073e9SAndroid Build Coastguard Worker ['01', '110'] 
1207*da0073e9SAndroid Build Coastguard Worker ============================= Mismatch node kinds: ============================= 1208*da0073e9SAndroid Build Coastguard Worker {'aten::relu': 2} 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker """ 1211*da0073e9SAndroid Build Coastguard Worker GraphInfoPrettyPrinter(self).pretty_print() 1212*da0073e9SAndroid Build Coastguard Worker 1213*da0073e9SAndroid Build Coastguard Worker def pretty_print_mismatch(self, graph: bool = False): 1214*da0073e9SAndroid Build Coastguard Worker """Pretty print details of the mismatch between torch and ONNX. 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker Args: 1217*da0073e9SAndroid Build Coastguard Worker graph: If True, print the ATen JIT graph and ONNX graph. 1218*da0073e9SAndroid Build Coastguard Worker """ 1219*da0073e9SAndroid Build Coastguard Worker print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) 1220*da0073e9SAndroid Build Coastguard Worker if graph: 1221*da0073e9SAndroid Build Coastguard Worker print(" ATen JIT graph ".center(80, "=")) 1222*da0073e9SAndroid Build Coastguard Worker # TODO: A more compact graph printer. 1223*da0073e9SAndroid Build Coastguard Worker # * Drop stride, grad, device information. 1224*da0073e9SAndroid Build Coastguard Worker # * Show source location on a separate line. 
1225*da0073e9SAndroid Build Coastguard Worker print(self.graph) 1226*da0073e9SAndroid Build Coastguard Worker if self._onnx_graph is not None: 1227*da0073e9SAndroid Build Coastguard Worker print(" ONNX graph ".center(80, "=")) 1228*da0073e9SAndroid Build Coastguard Worker print(self._onnx_graph) 1229*da0073e9SAndroid Build Coastguard Worker if self.has_mismatch(): 1230*da0073e9SAndroid Build Coastguard Worker print(" Mismatch error ".center(80, "=")) 1231*da0073e9SAndroid Build Coastguard Worker print(self.mismatch_error) 1232*da0073e9SAndroid Build Coastguard Worker else: 1233*da0073e9SAndroid Build Coastguard Worker print(" No mismatch ".center(80, "=")) 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker def has_mismatch(self) -> bool: 1236*da0073e9SAndroid Build Coastguard Worker """Return True if the subgraph has output mismatch between torch and ONNX.""" 1237*da0073e9SAndroid Build Coastguard Worker return self.mismatch_error is not None 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker def essential_node_count(self) -> int: 1240*da0073e9SAndroid Build Coastguard Worker """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" 1241*da0073e9SAndroid Build Coastguard Worker return sum( 1242*da0073e9SAndroid Build Coastguard Worker 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS 1243*da0073e9SAndroid Build Coastguard Worker ) 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker def essential_node_kinds(self) -> set[str]: 1246*da0073e9SAndroid Build Coastguard Worker """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" 1247*da0073e9SAndroid Build Coastguard Worker return { 1248*da0073e9SAndroid Build Coastguard Worker n.kind() 1249*da0073e9SAndroid Build Coastguard Worker for n in self.graph.nodes() 1250*da0073e9SAndroid Build Coastguard Worker 
if n.kind() not in self._EXCLUDED_NODE_KINDS 1251*da0073e9SAndroid Build Coastguard Worker } 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: 1254*da0073e9SAndroid Build Coastguard Worker """Return a list of all leaf `GraphInfo` objects that have mismatch.""" 1255*da0073e9SAndroid Build Coastguard Worker if not self.has_mismatch(): 1256*da0073e9SAndroid Build Coastguard Worker return [] 1257*da0073e9SAndroid Build Coastguard Worker 1258*da0073e9SAndroid Build Coastguard Worker no_mismatch_children = ( 1259*da0073e9SAndroid Build Coastguard Worker self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() 1260*da0073e9SAndroid Build Coastguard Worker ) and ( 1261*da0073e9SAndroid Build Coastguard Worker self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() 1262*da0073e9SAndroid Build Coastguard Worker ) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker if no_mismatch_children: 1265*da0073e9SAndroid Build Coastguard Worker return [self] 1266*da0073e9SAndroid Build Coastguard Worker 1267*da0073e9SAndroid Build Coastguard Worker results = [] 1268*da0073e9SAndroid Build Coastguard Worker if self.upper_graph_info is not None: 1269*da0073e9SAndroid Build Coastguard Worker results += self.upper_graph_info.all_mismatch_leaf_graph_info() 1270*da0073e9SAndroid Build Coastguard Worker if self.lower_graph_info is not None: 1271*da0073e9SAndroid Build Coastguard Worker results += self.lower_graph_info.all_mismatch_leaf_graph_info() 1272*da0073e9SAndroid Build Coastguard Worker 1273*da0073e9SAndroid Build Coastguard Worker return results 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker def find_partition(self, id: str) -> GraphInfo | None: 1276*da0073e9SAndroid Build Coastguard Worker """Find the `GraphInfo` object with the given id.""" 1277*da0073e9SAndroid 
Build Coastguard Worker if id == self.id: 1278*da0073e9SAndroid Build Coastguard Worker return self 1279*da0073e9SAndroid Build Coastguard Worker current_length = len(self.id) 1280*da0073e9SAndroid Build Coastguard Worker if len(id) > current_length: 1281*da0073e9SAndroid Build Coastguard Worker if id[current_length] == "0" and self.upper_graph_info is not None: 1282*da0073e9SAndroid Build Coastguard Worker return self.upper_graph_info.find_partition(id) 1283*da0073e9SAndroid Build Coastguard Worker elif id[current_length] == "1" and self.lower_graph_info is not None: 1284*da0073e9SAndroid Build Coastguard Worker return self.lower_graph_info.find_partition(id) 1285*da0073e9SAndroid Build Coastguard Worker return None 1286*da0073e9SAndroid Build Coastguard Worker 1287*da0073e9SAndroid Build Coastguard Worker def export_repro( 1288*da0073e9SAndroid Build Coastguard Worker self, repro_dir: str | None = None, name: str | None = None 1289*da0073e9SAndroid Build Coastguard Worker ) -> str: 1290*da0073e9SAndroid Build Coastguard Worker """Export the subgraph to ONNX along with the input/output data for repro. 
1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker The repro directory will contain the following files:: 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker dir 1295*da0073e9SAndroid Build Coastguard Worker \u251c\u2500\u2500 test_<name> 1296*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 model.onnx 1297*da0073e9SAndroid Build Coastguard Worker \u2502 \u2514\u2500\u2500 test_data_set_0 1298*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 input_0.pb 1299*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 input_1.pb 1300*da0073e9SAndroid Build Coastguard Worker \u2502 \u251c\u2500\u2500 output_0.pb 1301*da0073e9SAndroid Build Coastguard Worker \u2502 \u2514\u2500\u2500 output_1.pb 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker Args: 1304*da0073e9SAndroid Build Coastguard Worker repro_dir: The directory to export the repro files to. Defaults to current 1305*da0073e9SAndroid Build Coastguard Worker working directory if None. 1306*da0073e9SAndroid Build Coastguard Worker name: An optional name for the test case folder: "test_{name}". 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker Returns: 1309*da0073e9SAndroid Build Coastguard Worker The path to the exported repro directory. 
1310*da0073e9SAndroid Build Coastguard Worker """ 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker if repro_dir is None: 1313*da0073e9SAndroid Build Coastguard Worker repro_dir = os.getcwd() 1314*da0073e9SAndroid Build Coastguard Worker repro_dir = os.path.join(repro_dir, "onnx_debug") 1315*da0073e9SAndroid Build Coastguard Worker 1316*da0073e9SAndroid Build Coastguard Worker onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( 1317*da0073e9SAndroid Build Coastguard Worker self.graph, self.export_options, self.params_dict 1318*da0073e9SAndroid Build Coastguard Worker ) 1319*da0073e9SAndroid Build Coastguard Worker 1320*da0073e9SAndroid Build Coastguard Worker proto, _ = _onnx_proto_from_onnx_graph( 1321*da0073e9SAndroid Build Coastguard Worker onnx_graph, self.export_options, onnx_params_dict 1322*da0073e9SAndroid Build Coastguard Worker ) 1323*da0073e9SAndroid Build Coastguard Worker return OnnxTestCaseRepro.create_test_case_repro( 1324*da0073e9SAndroid Build Coastguard Worker proto, self.input_args, self.pt_outs, repro_dir, name 1325*da0073e9SAndroid Build Coastguard Worker ) 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker def _graph_partition_pivot(self) -> int: 1328*da0073e9SAndroid Build Coastguard Worker """Find the pivot index to partition the graph. 1329*da0073e9SAndroid Build Coastguard Worker 1330*da0073e9SAndroid Build Coastguard Worker The pivot is the node that splits the graph into two parts. Each part should 1331*da0073e9SAndroid Build Coastguard Worker have the similar amount of nodes, excluding non essential ops, defined in 1332*da0073e9SAndroid Build Coastguard Worker `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. 1333*da0073e9SAndroid Build Coastguard Worker If the graph has an odd number of nodes, the upper part will have one more node. 
1334*da0073e9SAndroid Build Coastguard Worker If the graph does not have any node that can be partitioned, return -1. 1335*da0073e9SAndroid Build Coastguard Worker 1336*da0073e9SAndroid Build Coastguard Worker Returns: 1337*da0073e9SAndroid Build Coastguard Worker The index of the pivot node. 1338*da0073e9SAndroid Build Coastguard Worker """ 1339*da0073e9SAndroid Build Coastguard Worker included_node_indices = [ 1340*da0073e9SAndroid Build Coastguard Worker i 1341*da0073e9SAndroid Build Coastguard Worker for i, n in enumerate(self.graph.nodes()) 1342*da0073e9SAndroid Build Coastguard Worker if n.kind() not in self._EXCLUDED_NODE_KINDS 1343*da0073e9SAndroid Build Coastguard Worker ] 1344*da0073e9SAndroid Build Coastguard Worker half_idx = len(included_node_indices) // 2 - 1 1345*da0073e9SAndroid Build Coastguard Worker if half_idx >= 0 and len(included_node_indices) > half_idx: 1346*da0073e9SAndroid Build Coastguard Worker return included_node_indices[half_idx] + 1 1347*da0073e9SAndroid Build Coastguard Worker return -1 1348*da0073e9SAndroid Build Coastguard Worker 1349*da0073e9SAndroid Build Coastguard Worker def _partition_upper_graph(self) -> torch.Graph: 1350*da0073e9SAndroid Build Coastguard Worker pivot = self._graph_partition_pivot() 1351*da0073e9SAndroid Build Coastguard Worker if pivot == -1: 1352*da0073e9SAndroid Build Coastguard Worker return torch.Graph() 1353*da0073e9SAndroid Build Coastguard Worker graph = self.graph.copy() # Copy to not mutate parent graph. 1354*da0073e9SAndroid Build Coastguard Worker original_outputs = list(graph.outputs()) 1355*da0073e9SAndroid Build Coastguard Worker 1356*da0073e9SAndroid Build Coastguard Worker def _process_bridge_value_for_upper( 1357*da0073e9SAndroid Build Coastguard Worker new_outputs: list[torch.Value], bridge_value: torch.Value 1358*da0073e9SAndroid Build Coastguard Worker ) -> torch.Value: 1359*da0073e9SAndroid Build Coastguard Worker # Add bridge values as upper graph outputs. 
1360*da0073e9SAndroid Build Coastguard Worker new_outputs.append(bridge_value) 1361*da0073e9SAndroid Build Coastguard Worker return bridge_value 1362*da0073e9SAndroid Build Coastguard Worker 1363*da0073e9SAndroid Build Coastguard Worker new_outputs: list[torch.Value] = [] 1364*da0073e9SAndroid Build Coastguard Worker process_bridge_value_for_upper = functools.partial( 1365*da0073e9SAndroid Build Coastguard Worker _process_bridge_value_for_upper, new_outputs 1366*da0073e9SAndroid Build Coastguard Worker ) 1367*da0073e9SAndroid Build Coastguard Worker _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( 1368*da0073e9SAndroid Build Coastguard Worker graph, pivot, process_bridge_value_for_upper 1369*da0073e9SAndroid Build Coastguard Worker ) 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker for _ in enumerate(original_outputs): 1372*da0073e9SAndroid Build Coastguard Worker graph.eraseOutput(0) 1373*da0073e9SAndroid Build Coastguard Worker for output in new_outputs: 1374*da0073e9SAndroid Build Coastguard Worker graph.registerOutput(output) 1375*da0073e9SAndroid Build Coastguard Worker 1376*da0073e9SAndroid Build Coastguard Worker for node in reversed(dropped_nodes): 1377*da0073e9SAndroid Build Coastguard Worker node.destroy() 1378*da0073e9SAndroid Build Coastguard Worker 1379*da0073e9SAndroid Build Coastguard Worker for i, input in reversed(list(enumerate(list(graph.inputs())))): 1380*da0073e9SAndroid Build Coastguard Worker if ( 1381*da0073e9SAndroid Build Coastguard Worker not _has_uses_by_nodes(input, complete_upper_nodes_set) 1382*da0073e9SAndroid Build Coastguard Worker and input not in new_outputs 1383*da0073e9SAndroid Build Coastguard Worker ): 1384*da0073e9SAndroid Build Coastguard Worker try: 1385*da0073e9SAndroid Build Coastguard Worker graph.eraseInput(i) 1386*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1387*da0073e9SAndroid Build Coastguard Worker print(input, graph) 
1388*da0073e9SAndroid Build Coastguard Worker raise e 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker return graph 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker def _partition_lower_graph(self) -> torch.Graph: 1393*da0073e9SAndroid Build Coastguard Worker pivot = self._graph_partition_pivot() 1394*da0073e9SAndroid Build Coastguard Worker if pivot == -1: 1395*da0073e9SAndroid Build Coastguard Worker return torch.Graph() 1396*da0073e9SAndroid Build Coastguard Worker graph = self.graph.copy() # Copy to not mutate parent graph. 1397*da0073e9SAndroid Build Coastguard Worker original_outputs = list(graph.outputs()) 1398*da0073e9SAndroid Build Coastguard Worker original_inputs = list(graph.inputs()) 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker new_outputs = [] 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker def _process_bridge_value_for_lower( 1403*da0073e9SAndroid Build Coastguard Worker graph: torch.Graph, bridge_value: torch.Value 1404*da0073e9SAndroid Build Coastguard Worker ) -> torch.Value: 1405*da0073e9SAndroid Build Coastguard Worker # Add bridge values as lower graph inputs. 
1406*da0073e9SAndroid Build Coastguard Worker new_input = graph.addInput() 1407*da0073e9SAndroid Build Coastguard Worker bridge_value.replaceAllUsesWith(new_input) 1408*da0073e9SAndroid Build Coastguard Worker new_input.copyMetadata(bridge_value) 1409*da0073e9SAndroid Build Coastguard Worker return new_input 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker process_bridge_value_for_lower = functools.partial( 1412*da0073e9SAndroid Build Coastguard Worker _process_bridge_value_for_lower, graph 1413*da0073e9SAndroid Build Coastguard Worker ) 1414*da0073e9SAndroid Build Coastguard Worker 1415*da0073e9SAndroid Build Coastguard Worker upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( 1416*da0073e9SAndroid Build Coastguard Worker graph, pivot, process_bridge_value_for_lower 1417*da0073e9SAndroid Build Coastguard Worker ) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker for output in original_outputs: 1420*da0073e9SAndroid Build Coastguard Worker if _produced_by(output, lower_nodes): 1421*da0073e9SAndroid Build Coastguard Worker new_outputs.append(output) 1422*da0073e9SAndroid Build Coastguard Worker for _ in enumerate(original_outputs): 1423*da0073e9SAndroid Build Coastguard Worker graph.eraseOutput(0) 1424*da0073e9SAndroid Build Coastguard Worker for output in new_outputs: 1425*da0073e9SAndroid Build Coastguard Worker graph.registerOutput(output) 1426*da0073e9SAndroid Build Coastguard Worker 1427*da0073e9SAndroid Build Coastguard Worker for input in original_inputs: 1428*da0073e9SAndroid Build Coastguard Worker if _has_uses_by_nodes(input, complete_lower_nodes_set): 1429*da0073e9SAndroid Build Coastguard Worker new_input = graph.addInput() 1430*da0073e9SAndroid Build Coastguard Worker input.replaceAllUsesWith(new_input) 1431*da0073e9SAndroid Build Coastguard Worker new_input.copyMetadata(input) 1432*da0073e9SAndroid Build Coastguard Worker 
1433*da0073e9SAndroid Build Coastguard Worker for node in reversed(upper_nodes): 1434*da0073e9SAndroid Build Coastguard Worker if node not in complete_lower_nodes_set: 1435*da0073e9SAndroid Build Coastguard Worker try: 1436*da0073e9SAndroid Build Coastguard Worker node.destroy() 1437*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1438*da0073e9SAndroid Build Coastguard Worker print(node, graph) 1439*da0073e9SAndroid Build Coastguard Worker raise e 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker for _ in original_inputs: 1442*da0073e9SAndroid Build Coastguard Worker graph.eraseInput(0) 1443*da0073e9SAndroid Build Coastguard Worker 1444*da0073e9SAndroid Build Coastguard Worker return graph 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker def _partition_node( 1447*da0073e9SAndroid Build Coastguard Worker self, 1448*da0073e9SAndroid Build Coastguard Worker node: torch.Node, 1449*da0073e9SAndroid Build Coastguard Worker complete_upper_nodes_set: set[torch.Node], 1450*da0073e9SAndroid Build Coastguard Worker complete_lower_nodes_set: set[torch.Node], 1451*da0073e9SAndroid Build Coastguard Worker original_graph_outputs: set[torch.Value], 1452*da0073e9SAndroid Build Coastguard Worker covered_bridge_values: set[torch.Value], 1453*da0073e9SAndroid Build Coastguard Worker process_bridge_value: Callable[[torch.Value], torch.Value], 1454*da0073e9SAndroid Build Coastguard Worker ): 1455*da0073e9SAndroid Build Coastguard Worker if node in complete_lower_nodes_set: 1456*da0073e9SAndroid Build Coastguard Worker return 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker if ( 1459*da0073e9SAndroid Build Coastguard Worker _node_has_uses_by(node, complete_lower_nodes_set) 1460*da0073e9SAndroid Build Coastguard Worker and node.kind() in self._EXCLUDED_NODE_KINDS 1461*da0073e9SAndroid Build Coastguard Worker ): 1462*da0073e9SAndroid Build 
Coastguard Worker complete_lower_nodes_set.update(_all_nodes([node])) 1463*da0073e9SAndroid Build Coastguard Worker for input in node.inputs(): 1464*da0073e9SAndroid Build Coastguard Worker if input in covered_bridge_values: 1465*da0073e9SAndroid Build Coastguard Worker continue 1466*da0073e9SAndroid Build Coastguard Worker self._partition_node( 1467*da0073e9SAndroid Build Coastguard Worker input.node(), 1468*da0073e9SAndroid Build Coastguard Worker complete_upper_nodes_set, 1469*da0073e9SAndroid Build Coastguard Worker complete_lower_nodes_set, 1470*da0073e9SAndroid Build Coastguard Worker original_graph_outputs, 1471*da0073e9SAndroid Build Coastguard Worker covered_bridge_values, 1472*da0073e9SAndroid Build Coastguard Worker process_bridge_value, 1473*da0073e9SAndroid Build Coastguard Worker ) 1474*da0073e9SAndroid Build Coastguard Worker else: 1475*da0073e9SAndroid Build Coastguard Worker for output in node.outputs(): 1476*da0073e9SAndroid Build Coastguard Worker if output in covered_bridge_values: 1477*da0073e9SAndroid Build Coastguard Worker continue 1478*da0073e9SAndroid Build Coastguard Worker if ( 1479*da0073e9SAndroid Build Coastguard Worker _has_uses_by_nodes(output, complete_lower_nodes_set) 1480*da0073e9SAndroid Build Coastguard Worker or output in original_graph_outputs 1481*da0073e9SAndroid Build Coastguard Worker ): 1482*da0073e9SAndroid Build Coastguard Worker covered_bridge_values.add(process_bridge_value(output)) 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Worker def _partition_nodes( 1485*da0073e9SAndroid Build Coastguard Worker self, 1486*da0073e9SAndroid Build Coastguard Worker graph: torch.Graph, 1487*da0073e9SAndroid Build Coastguard Worker pivot: int, 1488*da0073e9SAndroid Build Coastguard Worker process_bridge_value: Callable[[torch.Value], torch.Value], 1489*da0073e9SAndroid Build Coastguard Worker ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: 
1490*da0073e9SAndroid Build Coastguard Worker nodes = list(graph.nodes()) 1491*da0073e9SAndroid Build Coastguard Worker upper_nodes = nodes[:pivot] 1492*da0073e9SAndroid Build Coastguard Worker lower_nodes = nodes[pivot:] 1493*da0073e9SAndroid Build Coastguard Worker # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter 1494*da0073e9SAndroid Build Coastguard Worker # recursively contains nodes in subblock of `upper_nodes`. 1495*da0073e9SAndroid Build Coastguard Worker # The same applies for `lower_nodes` and `complete_lower_nodes_set`. 1496*da0073e9SAndroid Build Coastguard Worker # With addition that `complete_lower_nodes_set` will include nodes that 1497*da0073e9SAndroid Build Coastguard Worker # are determined to be copied from `upper_nodes` to `lower_nodes`. 1498*da0073e9SAndroid Build Coastguard Worker complete_upper_nodes_set = _all_nodes(upper_nodes) 1499*da0073e9SAndroid Build Coastguard Worker complete_lower_nodes_set = _all_nodes(lower_nodes) 1500*da0073e9SAndroid Build Coastguard Worker original_graph_outputs = set(graph.outputs()) 1501*da0073e9SAndroid Build Coastguard Worker # Bridge values are values produced from upper graph, and consumed 1502*da0073e9SAndroid Build Coastguard Worker # by lower graph. These values need to be become upper graph outputs 1503*da0073e9SAndroid Build Coastguard Worker # and lower graph inputs, to bridge the interaction. 1504*da0073e9SAndroid Build Coastguard Worker # Start with all graph inputs marked as covered. If any graph input is 1505*da0073e9SAndroid Build Coastguard Worker # needed by lower graph, just keep it in lower graph inputs later. 
1506*da0073e9SAndroid Build Coastguard Worker covered_bridge_values = set(graph.inputs()) 1507*da0073e9SAndroid Build Coastguard Worker for node in upper_nodes: 1508*da0073e9SAndroid Build Coastguard Worker self._partition_node( 1509*da0073e9SAndroid Build Coastguard Worker node, 1510*da0073e9SAndroid Build Coastguard Worker complete_upper_nodes_set, 1511*da0073e9SAndroid Build Coastguard Worker complete_lower_nodes_set, 1512*da0073e9SAndroid Build Coastguard Worker original_graph_outputs, 1513*da0073e9SAndroid Build Coastguard Worker covered_bridge_values, 1514*da0073e9SAndroid Build Coastguard Worker process_bridge_value, 1515*da0073e9SAndroid Build Coastguard Worker ) 1516*da0073e9SAndroid Build Coastguard Worker return ( 1517*da0073e9SAndroid Build Coastguard Worker upper_nodes, 1518*da0073e9SAndroid Build Coastguard Worker lower_nodes, 1519*da0073e9SAndroid Build Coastguard Worker complete_upper_nodes_set, 1520*da0073e9SAndroid Build Coastguard Worker complete_lower_nodes_set, 1521*da0073e9SAndroid Build Coastguard Worker ) 1522*da0073e9SAndroid Build Coastguard Worker 1523*da0073e9SAndroid Build Coastguard Worker def _bridge_kwargs(self): 1524*da0073e9SAndroid Build Coastguard Worker pt_outs = self.pt_outs 1525*da0073e9SAndroid Build Coastguard Worker graph_outputs = list(self.graph.outputs()) 1526*da0073e9SAndroid Build Coastguard Worker assert pt_outs is not None 1527*da0073e9SAndroid Build Coastguard Worker assert len(graph_outputs) == len( 1528*da0073e9SAndroid Build Coastguard Worker pt_outs 1529*da0073e9SAndroid Build Coastguard Worker ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" 1530*da0073e9SAndroid Build Coastguard Worker return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker def _args_and_params_for_partition_graph( 1533*da0073e9SAndroid Build Coastguard Worker self, 1534*da0073e9SAndroid Build Coastguard Worker graph: 
torch.Graph, 1535*da0073e9SAndroid Build Coastguard Worker bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], 1536*da0073e9SAndroid Build Coastguard Worker full_kwargs: Mapping[str, torch.Tensor], 1537*da0073e9SAndroid Build Coastguard Worker full_params: Mapping[str, torch.Tensor], 1538*da0073e9SAndroid Build Coastguard Worker ): 1539*da0073e9SAndroid Build Coastguard Worker input_names = [input.debugName() for input in graph.inputs()] 1540*da0073e9SAndroid Build Coastguard Worker args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) 1541*da0073e9SAndroid Build Coastguard Worker args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) 1542*da0073e9SAndroid Build Coastguard Worker params = {k: full_params[k] for k in input_names if k in full_params} 1543*da0073e9SAndroid Build Coastguard Worker assert len(args) + len(params) == len( 1544*da0073e9SAndroid Build Coastguard Worker input_names 1545*da0073e9SAndroid Build Coastguard Worker ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" 1546*da0073e9SAndroid Build Coastguard Worker return args, params 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Worker def verify_export( 1549*da0073e9SAndroid Build Coastguard Worker self, options: VerificationOptions 1550*da0073e9SAndroid Build Coastguard Worker ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: 1551*da0073e9SAndroid Build Coastguard Worker """ 1552*da0073e9SAndroid Build Coastguard Worker Verify the export from TorchScript IR graph to ONNX. 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker Export the TorchScript IR graph to ONNX, with the inputs, parameters and export 1555*da0073e9SAndroid Build Coastguard Worker options recorded in this object. 
Then verify the exported ONNX graph against 1556*da0073e9SAndroid Build Coastguard Worker the original TorchScript IR graph under the provided verification options. 1557*da0073e9SAndroid Build Coastguard Worker 1558*da0073e9SAndroid Build Coastguard Worker Args: 1559*da0073e9SAndroid Build Coastguard Worker options: The verification options. 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker Returns: 1562*da0073e9SAndroid Build Coastguard Worker error: The AssertionError raised during the verification. Returns None if no 1563*da0073e9SAndroid Build Coastguard Worker error is raised. 1564*da0073e9SAndroid Build Coastguard Worker onnx_graph: The exported ONNX graph in TorchScript IR format. 1565*da0073e9SAndroid Build Coastguard Worker onnx_outs: The outputs from running exported ONNX model under the onnx 1566*da0073e9SAndroid Build Coastguard Worker backend in `options`. 1567*da0073e9SAndroid Build Coastguard Worker pt_outs: The outputs from running the TorchScript IR graph. 
1568*da0073e9SAndroid Build Coastguard Worker """ 1569*da0073e9SAndroid Build Coastguard Worker return verify_aten_graph( 1570*da0073e9SAndroid Build Coastguard Worker self.graph, 1571*da0073e9SAndroid Build Coastguard Worker input_args=self.input_args, 1572*da0073e9SAndroid Build Coastguard Worker params_dict=self.params_dict, 1573*da0073e9SAndroid Build Coastguard Worker export_options=self.export_options, 1574*da0073e9SAndroid Build Coastguard Worker verification_options=options, 1575*da0073e9SAndroid Build Coastguard Worker ) 1576*da0073e9SAndroid Build Coastguard Worker 1577*da0073e9SAndroid Build Coastguard Worker def find_mismatch( 1578*da0073e9SAndroid Build Coastguard Worker self, 1579*da0073e9SAndroid Build Coastguard Worker options: VerificationOptions | None = None, 1580*da0073e9SAndroid Build Coastguard Worker ): 1581*da0073e9SAndroid Build Coastguard Worker """ 1582*da0073e9SAndroid Build Coastguard Worker Find all mismatches between the TorchScript IR graph and the exported onnx model. 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker Binary searches the model graph to find the minimal subgraph that exhibits the 1585*da0073e9SAndroid Build Coastguard Worker mismatch. A `GraphInfo` object is created for each subgraph, recording the test 1586*da0073e9SAndroid Build Coastguard Worker inputs and export options, as well as the validation results. 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker Args: 1589*da0073e9SAndroid Build Coastguard Worker options: The verification options. 
1590*da0073e9SAndroid Build Coastguard Worker """ 1591*da0073e9SAndroid Build Coastguard Worker self.clear() 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker if options is None: 1594*da0073e9SAndroid Build Coastguard Worker options = VerificationOptions() 1595*da0073e9SAndroid Build Coastguard Worker 1596*da0073e9SAndroid Build Coastguard Worker if self.export_options.verbose: 1597*da0073e9SAndroid Build Coastguard Worker print(self.graph) 1598*da0073e9SAndroid Build Coastguard Worker 1599*da0073e9SAndroid Build Coastguard Worker if len(list(self.graph.outputs())) == 0: 1600*da0073e9SAndroid Build Coastguard Worker return 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker assert len(self.input_args) + len(self.params_dict) == len( 1603*da0073e9SAndroid Build Coastguard Worker list(self.graph.inputs()) 1604*da0073e9SAndroid Build Coastguard Worker ), ( 1605*da0073e9SAndroid Build Coastguard Worker f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " 1606*da0073e9SAndroid Build Coastguard Worker f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." 1607*da0073e9SAndroid Build Coastguard Worker ) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( 1610*da0073e9SAndroid Build Coastguard Worker options 1611*da0073e9SAndroid Build Coastguard Worker ) 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker if self.mismatch_error is None: 1614*da0073e9SAndroid Build Coastguard Worker # No mismatch found in graph. 1615*da0073e9SAndroid Build Coastguard Worker return 1616*da0073e9SAndroid Build Coastguard Worker 1617*da0073e9SAndroid Build Coastguard Worker if self.essential_node_count() <= 1: 1618*da0073e9SAndroid Build Coastguard Worker # Reached leaf node, no more partitioning. 
1619*da0073e9SAndroid Build Coastguard Worker return 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker full_kwargs = { 1622*da0073e9SAndroid Build Coastguard Worker k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) 1623*da0073e9SAndroid Build Coastguard Worker } 1624*da0073e9SAndroid Build Coastguard Worker full_params = self.params_dict 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker upper_graph = self._partition_upper_graph() 1627*da0073e9SAndroid Build Coastguard Worker upper_args, upper_params = self._args_and_params_for_partition_graph( 1628*da0073e9SAndroid Build Coastguard Worker upper_graph, {}, full_kwargs, full_params 1629*da0073e9SAndroid Build Coastguard Worker ) 1630*da0073e9SAndroid Build Coastguard Worker self.upper_graph_info = GraphInfo( 1631*da0073e9SAndroid Build Coastguard Worker upper_graph, 1632*da0073e9SAndroid Build Coastguard Worker upper_args, 1633*da0073e9SAndroid Build Coastguard Worker upper_params, 1634*da0073e9SAndroid Build Coastguard Worker self.export_options, 1635*da0073e9SAndroid Build Coastguard Worker id=self.id + "0", 1636*da0073e9SAndroid Build Coastguard Worker ) 1637*da0073e9SAndroid Build Coastguard Worker 1638*da0073e9SAndroid Build Coastguard Worker self.upper_graph_info.find_mismatch(options) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker bridge_kwargs = self.upper_graph_info._bridge_kwargs() 1641*da0073e9SAndroid Build Coastguard Worker lower_graph = self._partition_lower_graph() 1642*da0073e9SAndroid Build Coastguard Worker lower_args, lower_params = self._args_and_params_for_partition_graph( 1643*da0073e9SAndroid Build Coastguard Worker lower_graph, bridge_kwargs, full_kwargs, full_params 1644*da0073e9SAndroid Build Coastguard Worker ) 1645*da0073e9SAndroid Build Coastguard Worker self.lower_graph_info = GraphInfo( 1646*da0073e9SAndroid Build Coastguard Worker 
lower_graph, 1647*da0073e9SAndroid Build Coastguard Worker lower_args, 1648*da0073e9SAndroid Build Coastguard Worker lower_params, 1649*da0073e9SAndroid Build Coastguard Worker self.export_options, 1650*da0073e9SAndroid Build Coastguard Worker id=self.id + "1", 1651*da0073e9SAndroid Build Coastguard Worker ) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker self.lower_graph_info.find_mismatch(options) 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Worker 1656*da0073e9SAndroid Build Coastguard Workerdef _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: 1657*da0073e9SAndroid Build Coastguard Worker all_nodes = set(nodes) 1658*da0073e9SAndroid Build Coastguard Worker for n in nodes: 1659*da0073e9SAndroid Build Coastguard Worker for b in n.blocks(): 1660*da0073e9SAndroid Build Coastguard Worker all_nodes.update(_all_nodes(list(b.nodes()))) 1661*da0073e9SAndroid Build Coastguard Worker return all_nodes 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker 1664*da0073e9SAndroid Build Coastguard Workerdef _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: 1665*da0073e9SAndroid Build Coastguard Worker return any(use.user in nodes for use in value.uses()) 1666*da0073e9SAndroid Build Coastguard Worker 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Workerdef _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: 1669*da0073e9SAndroid Build Coastguard Worker for output in node.outputs(): 1670*da0073e9SAndroid Build Coastguard Worker if _has_uses_by_nodes(output, nodes): 1671*da0073e9SAndroid Build Coastguard Worker return True 1672*da0073e9SAndroid Build Coastguard Worker return False 1673*da0073e9SAndroid Build Coastguard Worker 1674*da0073e9SAndroid Build Coastguard Worker 1675*da0073e9SAndroid Build Coastguard Workerdef _produced_by(value: torch.Value, nodes: 
Collection[torch.Node]) -> bool: 1676*da0073e9SAndroid Build Coastguard Worker return value.node() in nodes 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Workerdef find_mismatch( 1680*da0073e9SAndroid Build Coastguard Worker model: torch.nn.Module | torch.jit.ScriptModule, 1681*da0073e9SAndroid Build Coastguard Worker input_args: tuple[Any, ...], 1682*da0073e9SAndroid Build Coastguard Worker do_constant_folding: bool = True, 1683*da0073e9SAndroid Build Coastguard Worker training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 1684*da0073e9SAndroid Build Coastguard Worker opset_version: int | None = None, 1685*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs: bool = True, 1686*da0073e9SAndroid Build Coastguard Worker verbose: bool = False, 1687*da0073e9SAndroid Build Coastguard Worker options: VerificationOptions | None = None, 1688*da0073e9SAndroid Build Coastguard Worker) -> GraphInfo: 1689*da0073e9SAndroid Build Coastguard Worker r"""Find all mismatches between the original model and the exported model. 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker Experimental. The API is subject to change. 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker This tool helps debug the mismatch between the original PyTorch model and exported 1694*da0073e9SAndroid Build Coastguard Worker ONNX model. It binary searches the model graph to find the minimal subgraph that 1695*da0073e9SAndroid Build Coastguard Worker exhibits the mismatch. 1696*da0073e9SAndroid Build Coastguard Worker 1697*da0073e9SAndroid Build Coastguard Worker Args: 1698*da0073e9SAndroid Build Coastguard Worker model: The model to be exported. 1699*da0073e9SAndroid Build Coastguard Worker input_args: The input arguments to the model. 
1700*da0073e9SAndroid Build Coastguard Worker do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. 1701*da0073e9SAndroid Build Coastguard Worker training: Same as `training` in :func:`torch.onnx.export`. 1702*da0073e9SAndroid Build Coastguard Worker opset_version: Same as `opset_version` in :func:`torch.onnx.export`. 1703*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. 1704*da0073e9SAndroid Build Coastguard Worker verbose: Same as `verbose` in :func:`torch.onnx.export`. 1705*da0073e9SAndroid Build Coastguard Worker options: The options for the mismatch verification. 1706*da0073e9SAndroid Build Coastguard Worker 1707*da0073e9SAndroid Build Coastguard Worker Returns: 1708*da0073e9SAndroid Build Coastguard Worker A GraphInfo object that contains the mismatch information. 1709*da0073e9SAndroid Build Coastguard Worker 1710*da0073e9SAndroid Build Coastguard Worker Example:: 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker >>> import torch 1713*da0073e9SAndroid Build Coastguard Worker >>> import torch.onnx.verification 1714*da0073e9SAndroid Build Coastguard Worker >>> torch.manual_seed(0) 1715*da0073e9SAndroid Build Coastguard Worker >>> opset_version = 15 1716*da0073e9SAndroid Build Coastguard Worker >>> # Define a custom symbolic function for aten::relu. 1717*da0073e9SAndroid Build Coastguard Worker >>> # The custom symbolic function is incorrect, which will result in mismatches. 1718*da0073e9SAndroid Build Coastguard Worker >>> def incorrect_relu_symbolic_function(g, self): 1719*da0073e9SAndroid Build Coastguard Worker ... return self 1720*da0073e9SAndroid Build Coastguard Worker >>> torch.onnx.register_custom_op_symbolic( 1721*da0073e9SAndroid Build Coastguard Worker ... "aten::relu", 1722*da0073e9SAndroid Build Coastguard Worker ... 
incorrect_relu_symbolic_function, 1723*da0073e9SAndroid Build Coastguard Worker ... opset_version=opset_version, 1724*da0073e9SAndroid Build Coastguard Worker ... ) 1725*da0073e9SAndroid Build Coastguard Worker >>> class Model(torch.nn.Module): 1726*da0073e9SAndroid Build Coastguard Worker ... def __init__(self) -> None: 1727*da0073e9SAndroid Build Coastguard Worker ... super().__init__() 1728*da0073e9SAndroid Build Coastguard Worker ... self.layers = torch.nn.Sequential( 1729*da0073e9SAndroid Build Coastguard Worker ... torch.nn.Linear(3, 4), 1730*da0073e9SAndroid Build Coastguard Worker ... torch.nn.ReLU(), 1731*da0073e9SAndroid Build Coastguard Worker ... torch.nn.Linear(4, 5), 1732*da0073e9SAndroid Build Coastguard Worker ... torch.nn.ReLU(), 1733*da0073e9SAndroid Build Coastguard Worker ... torch.nn.Linear(5, 6), 1734*da0073e9SAndroid Build Coastguard Worker ... ) 1735*da0073e9SAndroid Build Coastguard Worker ... def forward(self, x): 1736*da0073e9SAndroid Build Coastguard Worker ... return self.layers(x) 1737*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) 1738*da0073e9SAndroid Build Coastguard Worker >>> graph_info = torch.onnx.verification.find_mismatch( 1739*da0073e9SAndroid Build Coastguard Worker ... Model(), 1740*da0073e9SAndroid Build Coastguard Worker ... (torch.randn(2, 3),), 1741*da0073e9SAndroid Build Coastguard Worker ... opset_version=opset_version, 1742*da0073e9SAndroid Build Coastguard Worker ... ) 1743*da0073e9SAndroid Build Coastguard Worker ===================== Mismatch info for graph partition : ====================== 1744*da0073e9SAndroid Build Coastguard Worker ================================ Mismatch error ================================ 1745*da0073e9SAndroid Build Coastguard Worker Tensor-likes are not close! 
1746*da0073e9SAndroid Build Coastguard Worker Mismatched elements: 12 / 12 (100.0%) 1747*da0073e9SAndroid Build Coastguard Worker Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) 1748*da0073e9SAndroid Build Coastguard Worker Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) 1749*da0073e9SAndroid Build Coastguard Worker ==================================== Tree: ===================================== 1750*da0073e9SAndroid Build Coastguard Worker 5 X __2 X __1 \u2713 1751*da0073e9SAndroid Build Coastguard Worker id: | id: 0 | id: 00 1752*da0073e9SAndroid Build Coastguard Worker | | 1753*da0073e9SAndroid Build Coastguard Worker | |__1 X (aten::relu) 1754*da0073e9SAndroid Build Coastguard Worker | id: 01 1755*da0073e9SAndroid Build Coastguard Worker | 1756*da0073e9SAndroid Build Coastguard Worker |__3 X __1 \u2713 1757*da0073e9SAndroid Build Coastguard Worker id: 1 | id: 10 1758*da0073e9SAndroid Build Coastguard Worker | 1759*da0073e9SAndroid Build Coastguard Worker |__2 X __1 X (aten::relu) 1760*da0073e9SAndroid Build Coastguard Worker id: 11 | id: 110 1761*da0073e9SAndroid Build Coastguard Worker | 1762*da0073e9SAndroid Build Coastguard Worker |__1 \u2713 1763*da0073e9SAndroid Build Coastguard Worker id: 111 1764*da0073e9SAndroid Build Coastguard Worker =========================== Mismatch leaf subgraphs: =========================== 1765*da0073e9SAndroid Build Coastguard Worker ['01', '110'] 1766*da0073e9SAndroid Build Coastguard Worker ============================= Mismatch node kinds: ============================= 1767*da0073e9SAndroid Build Coastguard Worker {'aten::relu': 2} 1768*da0073e9SAndroid Build Coastguard Worker 1769*da0073e9SAndroid Build Coastguard Worker """ 1770*da0073e9SAndroid Build Coastguard Worker if options is None: 1771*da0073e9SAndroid Build Coastguard Worker options = VerificationOptions() 1772*da0073e9SAndroid Build Coastguard Worker if opset_version is None: 
1773*da0073e9SAndroid Build Coastguard Worker opset_version = _constants.ONNX_DEFAULT_OPSET 1774*da0073e9SAndroid Build Coastguard Worker """From aten graph, do binary search on graph partition to find operator export discrepancy.""" 1775*da0073e9SAndroid Build Coastguard Worker # TODO: Copied from utils.py `export` until `_optimize_graph`. 1776*da0073e9SAndroid Build Coastguard Worker if training == torch.onnx.TrainingMode.TRAINING: 1777*da0073e9SAndroid Build Coastguard Worker model.train() 1778*da0073e9SAndroid Build Coastguard Worker elif training == torch.onnx.TrainingMode.EVAL: 1779*da0073e9SAndroid Build Coastguard Worker model.eval() 1780*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1781*da0073e9SAndroid Build Coastguard Worker inputs_for_export = _prepare_input_for_export(input_args, {}) 1782*da0073e9SAndroid Build Coastguard Worker args = utils._decide_input_format(model, inputs_for_export) 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker model = utils._pre_trace_quant_model(model, args) 1785*da0073e9SAndroid Build Coastguard Worker graph, params, torch_out, module = utils._create_jit_graph(model, args) 1786*da0073e9SAndroid Build Coastguard Worker params_dict = utils._get_named_param_dict(graph, params) 1787*da0073e9SAndroid Build Coastguard Worker 1788*da0073e9SAndroid Build Coastguard Worker utils._apply_friendly_debug_names(graph, params_dict) 1789*da0073e9SAndroid Build Coastguard Worker 1790*da0073e9SAndroid Build Coastguard Worker graph_info = GraphInfo( 1791*da0073e9SAndroid Build Coastguard Worker graph, 1792*da0073e9SAndroid Build Coastguard Worker input_args, 1793*da0073e9SAndroid Build Coastguard Worker params_dict, 1794*da0073e9SAndroid Build Coastguard Worker _experimental.ExportOptions( 1795*da0073e9SAndroid Build Coastguard Worker do_constant_folding=do_constant_folding, 1796*da0073e9SAndroid Build Coastguard Worker training=training, 1797*da0073e9SAndroid Build Coastguard Worker 
opset_version=opset_version, 1798*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs=keep_initializers_as_inputs, 1799*da0073e9SAndroid Build Coastguard Worker verbose=verbose, 1800*da0073e9SAndroid Build Coastguard Worker ), 1801*da0073e9SAndroid Build Coastguard Worker ) 1802*da0073e9SAndroid Build Coastguard Worker graph_info.find_mismatch(options) 1803*da0073e9SAndroid Build Coastguard Worker graph_info.pretty_print_mismatch() 1804*da0073e9SAndroid Build Coastguard Worker graph_info.pretty_print_tree() 1805*da0073e9SAndroid Build Coastguard Worker 1806*da0073e9SAndroid Build Coastguard Worker return graph_info 1807