xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/diagnostics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5import functools
6from typing import Any, TYPE_CHECKING
7
8import onnxscript  # type: ignore[import]
9from onnxscript.function_libs.torch_lib import graph_building  # type: ignore[import]
10
11import torch
12import torch.fx
13from torch.onnx._internal import diagnostics
14from torch.onnx._internal.diagnostics import infra
15from torch.onnx._internal.diagnostics.infra import decorator, formatter
16from torch.onnx._internal.fx import registration, type_utils as fx_type_utils
17
18
19if TYPE_CHECKING:
20    import logging
21
# NOTE: The following limit caps the number of items displayed in diagnostics for
# a list, tuple or dict. The limit is picked such that common useful scenarios such as
# operator arguments are covered, while preventing excessive processing load on very
# large containers such as the dictionary mapping from fx to onnx nodes.
_CONTAINER_ITEM_LIMIT: int = 10
27
28# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
29# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
30# cannot be put there.
31
32# [NOTE: `dynamo_export` diagnostics logging]
# The 'dynamo_export' diagnostics leverage the PT2 artifact logger to control the verbosity
# level of the logs recorded in each SARIF log diagnostic. Unlike the SARIF log, terminal
# logging is disabled by default. It can be activated by setting the environment variable
# `TORCH_LOGS="onnx_diagnostics"`. Setting that variable also pins the logging level to
# `logging.DEBUG`, overriding the verbosity level specified in the diagnostic options.
# See `torch/_logging/__init__.py` for more on PT2 logging.
# Artifact name shared by the logger below and by
# `is_onnx_diagnostics_log_artifact_enabled`; it is the value expected in `TORCH_LOGS`.
_ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME = "onnx_diagnostics"
# Artifact logger registered under the "torch.onnx" namespace. It is the default
# `logger` of both `Diagnostic` and `DiagnosticContext` defined later in this module.
diagnostic_logger = torch._logging.getArtifactLogger(
    "torch.onnx", _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
)
44
45
def is_onnx_diagnostics_log_artifact_enabled() -> bool:
    """Return whether the ``onnx_diagnostics`` logging artifact is currently enabled.

    The answer is read from the PT2 logging state on every call, so it reflects
    the configuration in effect at call time.
    """
    log_state = torch._logging._internal.log_state
    return log_state.is_artifact_enabled(_ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME)
50
51
@functools.singledispatch
def _format_argument(obj: Any) -> str:
    """Fallback formatter for types without a dedicated registration.

    Type-specific formatters are attached through ``_format_argument.register``;
    anything else is handed to ``formatter.format_argument`` from the shared
    diagnostics infra.
    """
    default_text = formatter.format_argument(obj)
    return default_text
55
56
def format_argument(obj: Any) -> str:
    """Format ``obj`` for display in a diagnostic.

    Looks up the formatter registered for ``type(obj)`` in the
    ``_format_argument`` single-dispatch registry and applies it.
    """
    # Named `handler` so the imported `formatter` module is not shadowed locally.
    handler = _format_argument.dispatch(type(obj))
    return handler(obj)
60
61
62# NOTE: EDITING BELOW? READ THIS FIRST!
63#
64# The below functions register the `format_argument` function for different types via
65# `functools.singledispatch` registry. These are invoked by the diagnostics system
66# when recording function arguments and return values as part of a diagnostic.
67# Hence, code with heavy workload should be avoided. Things to avoid for example:
68# `torch.fx.GraphModule.print_readable()`.
69
70
@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    """Format a module by class name only; its contents are not rendered."""
    class_name = obj.__class__.__name__
    return f"torch.nn.Module({class_name})"
74
75
@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    """Format a graph module by class name only; the graph itself is not printed."""
    class_name = obj.__class__.__name__
    return f"torch.fx.GraphModule({class_name})"
79
80
@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
    """Format an fx node as ``fx.Node(<target>)[<op>]:<val>``.

    ``<val>`` is the formatted ``obj.meta["val"]`` entry, or the literal
    ``None`` when the node's meta has no ``"val"`` key.
    """
    prefix = f"fx.Node({obj.target})[{obj.op}]:"
    if "val" in obj.meta:
        return prefix + format_argument(obj.meta["val"])
    return prefix + "None"
87
88
@_format_argument.register
def _torch_fx_symbolic_bool(obj: torch.SymBool) -> str:
    """Format a symbolic bool as ``SymBool(<obj>)``."""
    rendered = f"{obj}"
    return f"SymBool({rendered})"
92
93
@_format_argument.register
def _torch_fx_symbolic_int(obj: torch.SymInt) -> str:
    """Format a symbolic int as ``SymInt(<obj>)``."""
    rendered = f"{obj}"
    return f"SymInt({rendered})"
97
98
@_format_argument.register
def _torch_fx_symbolic_float(obj: torch.SymFloat) -> str:
    """Format a symbolic float as ``SymFloat(<obj>)``."""
    rendered = f"{obj}"
    return f"SymFloat({rendered})"
102
103
@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    """Format a tensor as ``Tensor(<dtype abbr><shape>)``; values are not rendered."""
    dtype_abbr = fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)
    shape_text = _stringify_shape(obj.shape)
    return f"Tensor({dtype_abbr}{shape_text})"
107
108
@_format_argument.register
def _int(obj: int) -> str:
    """Format an ``int`` using its ``str()`` form."""
    text = str(obj)
    return text
112
113
@_format_argument.register
def _float(obj: float) -> str:
    """Format a ``float`` using its ``str()`` form."""
    text = str(obj)
    return text
117
118
@_format_argument.register
def _bool(obj: bool) -> str:
    """Format a ``bool`` using its ``str()`` form (``True`` / ``False``).

    ``bool`` subclasses ``int``; single dispatch selects the most specific
    registered type, so booleans reach this formatter rather than ``_int``.
    """
    text = str(obj)
    return text
122
123
@_format_argument.register
def _str(obj: str) -> str:
    """Return the string unchanged; no quoting or escaping is applied."""
    unchanged = obj
    return unchanged
127
128
@_format_argument.register
def _registration_onnx_function(obj: registration.ONNXFunction) -> str:
    """Format a registered ONNX function by name plus its custom/complex flags."""
    # TODO: Compact display of `param_schema`.
    full_name = obj.op_full_name
    custom_flag = obj.is_custom
    complex_flag = obj.is_complex
    return (
        f"registration.ONNXFunction({full_name}, "
        f"is_custom={custom_flag}, is_complex={complex_flag})"
    )
133
134
@_format_argument.register
def _list(obj: list) -> str:
    """Format a list as a header with its length followed by one item per line.

    An empty list renders as ``None`` inside the parentheses. At most
    ``_CONTAINER_ITEM_LIMIT`` items are formatted; when the list is longer an
    ellipsis line marks the truncation.
    """
    header = f"List[length={len(obj)}](\n"
    if not obj:
        return header + "None)"
    # Pairing with a bounded range stops the walk once the limit is reached.
    lines = [
        f"{format_argument(item)},\n"
        for _, item in zip(range(_CONTAINER_ITEM_LIMIT), obj)
    ]
    if len(obj) > _CONTAINER_ITEM_LIMIT:
        # NOTE: Only the first _CONTAINER_ITEM_LIMIT items are printed.
        lines.append("...,\n")
    return header + "".join(lines) + ")"
147
148
@_format_argument.register
def _tuple(obj: tuple) -> str:
    """Format a tuple as a header with its length followed by one item per line.

    An empty tuple renders as ``None`` inside the parentheses. At most
    ``_CONTAINER_ITEM_LIMIT`` items are formatted; when the tuple is longer an
    ellipsis line marks the truncation.
    """
    header = f"Tuple[length={len(obj)}](\n"
    if not obj:
        return header + "None)"
    # Pairing with a bounded range stops the walk once the limit is reached.
    lines = [
        f"{format_argument(item)},\n"
        for _, item in zip(range(_CONTAINER_ITEM_LIMIT), obj)
    ]
    if len(obj) > _CONTAINER_ITEM_LIMIT:
        # NOTE: Only the first _CONTAINER_ITEM_LIMIT items are printed.
        lines.append("...,\n")
    return header + "".join(lines) + ")"
161
162
@_format_argument.register
def _dict(obj: dict) -> str:
    """Format a dict as a header with its length followed by one entry per line.

    Keys are interpolated directly; values go through ``format_argument``. An
    empty dict renders as ``None`` inside the parentheses. At most
    ``_CONTAINER_ITEM_LIMIT`` entries are formatted; when the dict is larger an
    ellipsis line marks the truncation.
    """
    header = f"Dict[length={len(obj)}](\n"
    if not obj:
        return header + "None)"
    # Pairing with a bounded range stops the walk once the limit is reached.
    lines = [
        f"{key}: {format_argument(value)},\n"
        for _, (key, value) in zip(range(_CONTAINER_ITEM_LIMIT), obj.items())
    ]
    if len(obj) > _CONTAINER_ITEM_LIMIT:
        # NOTE: Only the first _CONTAINER_ITEM_LIMIT items are printed.
        # NOTE(review): this marker has no trailing comma, unlike the list and
        # tuple formatters; kept as-is to preserve the existing output.
        lines.append("...\n")
    return header + "".join(lines) + ")"
175
176
@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    """Format a parameter as ``Parameter(<formatted obj.data>)``."""
    inner = format_argument(obj.data)
    return f"Parameter({inner})"
180
181
@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
    """Format a ``TorchScriptTensor`` as a backquoted dtype abbreviation plus shape."""
    dtype_abbr = fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)  # type: ignore[arg-type]
    shape_text = _stringify_shape(obj.shape)  # type: ignore[arg-type]
    return f"`TorchScriptTensor({dtype_abbr}{shape_text})`"
185
186
@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
    """Format an ``OnnxFunction`` by its ``name``, wrapped in backquotes."""
    function_name = obj.name
    return f"`OnnxFunction({function_name})`"
190
191
@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
    """Format a ``TracedOnnxFunction`` by its ``name``, wrapped in backquotes."""
    function_name = obj.name
    return f"`TracedOnnxFunction({function_name})`"
195
196
197# from torch/fx/graph.py to follow torch format
198def _stringify_shape(shape: torch.Size | None) -> str:
199    if shape is None:
200        return ""
201    return f"[{', '.join(str(x) for x in shape)}]"
202
203
# Aliases for names defined in `torch.onnx._internal.diagnostics` and its `infra`
# package, exposed as attributes of this module.
rules = diagnostics.rules
levels = diagnostics.levels
RuntimeErrorWithDiagnostic = infra.RuntimeErrorWithDiagnostic
LazyString = formatter.LazyString
DiagnosticOptions = infra.DiagnosticOptions
209
210
@dataclasses.dataclass
class Diagnostic(infra.Diagnostic):
    """Diagnostic bound to the ``onnx_diagnostics`` artifact logger.

    Messages are always stored on the diagnostic; they are additionally sent
    to the logger (and hence the terminal) only when the artifact is enabled.
    """

    # Not an `__init__` argument: every instance uses the module-level artifact logger.
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)

    def log(self, level: int, message: str, *args, **kwargs) -> None:
        """Record a message on this diagnostic, optionally echoing it to the logger.

        Args:
            level: Logging level of the message. Nothing happens unless the
                logger is enabled for this level.
            message: The message, optionally a ``%``-style format string.
            *args: Values interpolated into ``message`` with the ``%`` operator.
            **kwargs: Forwarded to ``logging.Logger.log`` when the message is
                echoed to the logger.

        Raises:
            TypeError, ValueError: If ``args`` are given and do not match the
                format specifiers in ``message``.
        """
        if not self.logger.isEnabledFor(level):
            return

        try:
            formatted_message = message % args
        except (TypeError, ValueError):
            if args:
                raise
            # With no args the message is plain text that merely contains a
            # literal '%' (e.g. "50% done"). `logging.Logger.log` accepts such
            # messages unformatted, so do the same instead of failing.
            formatted_message = message

        if is_onnx_diagnostics_log_artifact_enabled():
            # Only log to terminal if artifact is enabled.
            # See [NOTE: `dynamo_export` diagnostics logging] for details.
            self.logger.log(level, formatted_message, **kwargs)

        self.additional_messages.append(formatted_message)
224
225
@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
    """Diagnostic context bound to this module's ``Diagnostic`` type and logger."""

    # Neither field is an `__init__` argument; both are fixed for this context type.
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    _bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
        init=False, default=Diagnostic
    )

    def __enter__(self):
        """Enter the context, leaving the logger level alone if the artifact is enabled.

        The current logger level is always saved in ``_previous_log_level``.
        The parent ``__enter__`` is invoked only when the ``onnx_diagnostics``
        artifact is disabled.
        """
        self._previous_log_level = self.logger.level
        # The logger level depends on `options.verbosity_level` and the environment
        # variable `TORCH_LOGS`. See [NOTE: `dynamo_export` diagnostics logging] for details.
        if is_onnx_diagnostics_log_artifact_enabled():
            # `TORCH_LOGS` already decides the level; skip the parent's adjustment.
            return self
        return super().__enter__()
241
242
# `decorator.diagnose_call` pre-bound to this module's `Diagnostic` type and
# `format_argument` formatter; callers supply only the remaining arguments.
diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostic_type=Diagnostic,
    format_argument=format_argument,
)
248
249
@dataclasses.dataclass
class UnsupportedFxNodeDiagnostic(Diagnostic):
    """Diagnostic that must carry the fx node that is unsupported."""

    # Declared optional only because of dataclass field ordering; a value is
    # required in practice and enforced in `__post_init__`.
    unsupported_fx_node: torch.fx.Node | None = None

    def __post_init__(self) -> None:
        """Run the parent initialisation, then require ``unsupported_fx_node``.

        Raises:
            ValueError: If ``unsupported_fx_node`` was left as ``None``.
        """
        super().__post_init__()
        # NOTE: This is a hack to make sure that the additional fields must be set and
        # not None. Ideally they should not be set as optional. But this is a known
        # limitation with `dataclasses`. Resolvable in Python 3.10 with `kw_only=True`.
        # https://stackoverflow.com/questions/69711886/python-dataclasses-inheritance-and-default-values
        node = self.unsupported_fx_node
        if node is None:
            raise ValueError("unsupported_fx_node must be specified.")
262