1"""ONNX exporter exceptions.""" 2 3from __future__ import annotations 4 5 6__all__ = [ 7 "OnnxExporterWarning", 8 "SymbolicValueError", 9 "UnsupportedOperatorError", 10] 11 12import textwrap 13from typing import TYPE_CHECKING 14 15 16if TYPE_CHECKING: 17 from torch import _C 18 19 20class OnnxExporterWarning(UserWarning): 21 """Warnings in the ONNX exporter.""" 22 23 24class OnnxExporterError(RuntimeError): 25 """Errors raised by the ONNX exporter. This is the base class for all exporter errors.""" 26 27 28class UnsupportedOperatorError(OnnxExporterError): 29 """Raised when an operator is unsupported by the exporter.""" 30 31 # NOTE: This is legacy and is only used by the torchscript exporter 32 # Clean up when the torchscript exporter is removed 33 def __init__(self, name: str, version: int, supported_version: int | None): 34 from torch.onnx import _constants 35 from torch.onnx._internal import diagnostics 36 37 if supported_version is not None: 38 diagnostic_rule: diagnostics.infra.Rule = ( 39 diagnostics.rules.operator_supported_in_newer_opset_version 40 ) 41 msg = diagnostic_rule.format_message(name, version, supported_version) 42 diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) 43 else: 44 if name.startswith(("aten::", "prim::", "quantized::")): 45 diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function 46 msg = diagnostic_rule.format_message( 47 name, version, _constants.PYTORCH_GITHUB_ISSUES_URL 48 ) 49 diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) 50 else: 51 diagnostic_rule = diagnostics.rules.missing_custom_symbolic_function 52 msg = diagnostic_rule.format_message(name) 53 diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) 54 super().__init__(msg) 55 56 57class SymbolicValueError(OnnxExporterError): 58 """Errors around TorchScript values and nodes.""" 59 60 # NOTE: This is legacy and is only used by the torchscript exporter 61 # Clean up when the torchscript exporter is removed 62 def __init__(self, msg: str, value: _C.Value): 63 message = ( 64 f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " 65 f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] " 66 ) 67 68 code_location = value.node().sourceRange() 69 if code_location: 70 message += f"\n (node defined in {code_location})" 71 72 try: 73 # Add its input and output to the message. 74 message += "\n\n" 75 message += textwrap.indent( 76 ( 77 "Inputs:\n" 78 + ( 79 "\n".join( 80 f" #{i}: {input_} (type '{input_.type()}')" 81 for i, input_ in enumerate(value.node().inputs()) 82 ) 83 or " Empty" 84 ) 85 + "\n" 86 + "Outputs:\n" 87 + ( 88 "\n".join( 89 f" #{i}: {output} (type '{output.type()}')" 90 for i, output in enumerate(value.node().outputs()) 91 ) 92 or " Empty" 93 ) 94 ), 95 " ", 96 ) 97 except AttributeError: 98 message += ( 99 " Failed to obtain its input and output for debugging. " 100 "Please refer to the TorchScript graph for debugging information." 101 ) 102 103 super().__init__(message) 104