xref: /aosp_15_r20/external/pytorch/torch/onnx/errors.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""ONNX exporter exceptions."""
2
3from __future__ import annotations
4
5
# Public names exported by ``from torch.onnx.errors import *``.
# NOTE(review): OnnxExporterError is not listed; presumably intentional (e.g. it
# is re-exported from a parent package) — confirm before adding it here.
__all__ = ["OnnxExporterWarning", "SymbolicValueError", "UnsupportedOperatorError"]
11
12import textwrap
13from typing import TYPE_CHECKING
14
15
16if TYPE_CHECKING:
17    from torch import _C
18
19
class OnnxExporterWarning(UserWarning):
    """Warning category used by the ONNX exporter.

    It derives from :class:`UserWarning`, so ``warnings`` filters can target
    exporter warnings specifically by using this class as the category.
    """
22
23
class OnnxExporterError(RuntimeError):
    """Root of the ONNX exporter's exception hierarchy.

    All exporter-specific errors derive from this class. Because it subclasses
    :class:`RuntimeError`, handlers written as ``except RuntimeError`` still
    catch it.
    """
26
27
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when the exporter cannot handle an operator.

    Construction records an ERROR-level diagnostic through
    ``torch.onnx._internal.diagnostics``. The formatted diagnostic message is
    also used as the exception message.

    Args:
        name: Qualified operator name, e.g. ``"aten::foo"``.
        version: Opset version passed to the diagnostic message
            (presumably the export target opset — confirm with callers).
        supported_version: When not ``None``, an opset version passed to the
            ``operator_supported_in_newer_opset_version`` rule's message.
            When ``None``, the operator's namespace picks between the
            "missing standard" and "missing custom" symbolic-function rules.
    """

    # NOTE: This is legacy and is only used by the torchscript exporter
    # Clean up when the torchscript exporter is removed
    def __init__(self, name: str, version: int, supported_version: int | None):
        # Function-local imports, presumably to avoid import cycles with torch.onnx.
        from torch.onnx import _constants
        from torch.onnx._internal import diagnostics

        rules = diagnostics.rules
        diagnostic_rule: diagnostics.infra.Rule
        if supported_version is not None:
            diagnostic_rule = rules.operator_supported_in_newer_opset_version
            msg = diagnostic_rule.format_message(name, version, supported_version)
        elif name.startswith(("aten::", "prim::", "quantized::")):
            # Built-in namespaces: the symbolic function is missing from PyTorch.
            diagnostic_rule = rules.missing_standard_symbolic_function
            msg = diagnostic_rule.format_message(
                name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
            )
        else:
            # Any other namespace is reported as a missing custom symbolic.
            diagnostic_rule = rules.missing_custom_symbolic_function
            msg = diagnostic_rule.format_message(name)

        # NOTE(review): if diagnose() records caller stack frames, moving this
        # call into a helper would shift them, so it stays directly in __init__.
        diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
        super().__init__(msg)
55
56
class SymbolicValueError(OnnxExporterError):
    """Error attributed to a specific value in a TorchScript graph.

    The message is the caller's ``msg``, followed by:

    * a description of ``value``: its string form, its type and the kind of
      its containing node;
    * the node's source range, when it is truthy;
    * an indented listing of the node's inputs and outputs.

    The listing is replaced by a generic hint if it cannot be produced.
    """

    # NOTE: This is legacy and is only used by the torchscript exporter
    # Clean up when the torchscript exporter is removed
    def __init__(self, msg: str, value: _C.Value):
        """Assemble the diagnostic message for ``value`` and initialize the error.

        Args:
            msg: Caller-supplied description of the problem.
            value: The TorchScript value the error is attributed to.
        """
        value_text = f"{value}"
        value_type = value.type()
        node = value.node()
        pieces = [
            f"{msg}  [Caused by the value '{value_text}' (type '{value_type}') in the "
            f"TorchScript graph. The containing node has kind '{node.kind()}'.] "
        ]

        source_range = node.sourceRange()
        if source_range:
            pieces.append(f"\n    (node defined in {source_range})")

        def listing(header, values):
            # One "#<index>: <value>  (type '<type>')" row per value, or a
            # placeholder when there are none.
            rows = "\n".join(
                f"    #{position}: {item}  (type '{item.type()}')"
                for position, item in enumerate(values)
            )
            return header + (rows or "    Empty")

        pieces.append("\n\n")
        try:
            # Objects lacking the expected accessors raise AttributeError here;
            # in that case a generic hint is appended instead of the listing.
            pieces.append(
                textwrap.indent(
                    listing("Inputs:\n", node.inputs())
                    + "\n"
                    + listing("Outputs:\n", node.outputs()),
                    "    ",
                )
            )
        except AttributeError:
            pieces.append(
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__("".join(pieces))
103        super().__init__(message)
104