# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
from typing import Any, TYPE_CHECKING

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_lib import graph_building  # type: ignore[import]

import torch
import torch.fx
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils


if TYPE_CHECKING:
    import logging

# NOTE: The following limit caps the number of items displayed in diagnostics for
# a list, tuple or dict. It is picked so that common useful scenarios, such as
# operator arguments, are covered, while avoiding excessive processing on very
# large containers such as the dictionary mapping from fx to onnx nodes.
_CONTAINER_ITEM_LIMIT: int = 10

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx` and is loaded with `torch`. Hence anything related to
# `onnxscript` cannot be put there.

# [NOTE: `dynamo_export` diagnostics logging]
# The 'dynamo_export' diagnostics leverage the PT2 artifact logger to control the
# verbosity level of the logs recorded in each SARIF log diagnostic. In addition to
# the SARIF log, terminal logging is disabled by default. Terminal logging can be
# activated by setting the environment variable `TORCH_LOGS="onnx_diagnostics"`.
# When the environment variable is set, it also fixes the logging level to
# `logging.DEBUG`, overriding the verbosity level specified in the diagnostic
# options. See `torch/_logging/__init__.py` for more on PT2 logging.
_ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME = "onnx_diagnostics"
diagnostic_logger = torch._logging.getArtifactLogger(
    "torch.onnx", _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
)


def is_onnx_diagnostics_log_artifact_enabled() -> bool:
    return torch._logging._internal.log_state.is_artifact_enabled(
        _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
    )


@functools.singledispatch
def _format_argument(obj: Any) -> str:
    return formatter.format_argument(obj)


def format_argument(obj: Any) -> str:
    # NOTE: Named `formatter_func` to avoid shadowing the imported `formatter` module.
    formatter_func = _format_argument.dispatch(type(obj))
    return formatter_func(obj)


# NOTE: EDITING BELOW? READ THIS FIRST!
#
# The functions below register `format_argument` handlers for different types via
# the `functools.singledispatch` registry. They are invoked by the diagnostics
# system when recording function arguments and return values as part of a
# diagnostic. Hence, code with a heavy workload should be avoided. Things to
# avoid, for example: `torch.fx.GraphModule.print_readable()`.
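#
# For example, a formatter for a new type would be registered like the sketch
# below (illustrative only; `MyWrapper` and its `name` attribute are hypothetical
# stand-ins for whatever type needs support):
#
#   @_format_argument.register
#   def _my_wrapper(obj: MyWrapper) -> str:
#       return f"MyWrapper({obj.name})"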


@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    return f"torch.nn.Module({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    return f"torch.fx.GraphModule({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
    node_string = f"fx.Node({obj.target})[{obj.op}]:"
    if "val" not in obj.meta:
        return node_string + "None"
    return node_string + format_argument(obj.meta["val"])


@_format_argument.register
def _torch_fx_symbolic_bool(obj: torch.SymBool) -> str:
    return f"SymBool({obj})"


@_format_argument.register
def _torch_fx_symbolic_int(obj: torch.SymInt) -> str:
    return f"SymInt({obj})"


@_format_argument.register
def _torch_fx_symbolic_float(obj: torch.SymFloat) -> str:
    return f"SymFloat({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _int(obj: int) -> str:
    return str(obj)


@_format_argument.register
def _float(obj: float) -> str:
    return str(obj)


@_format_argument.register
def _bool(obj: bool) -> str:
    return str(obj)


@_format_argument.register
def _str(obj: str) -> str:
    return obj


@_format_argument.register
def _registration_onnx_function(obj: registration.ONNXFunction) -> str:
    # TODO: Compact display of `param_schema`.
    return f"registration.ONNXFunction({obj.op_full_name}, is_custom={obj.is_custom}, is_complex={obj.is_complex})"


@_format_argument.register
def _list(obj: list) -> str:
    list_string = f"List[length={len(obj)}](\n"
    if not obj:
        return list_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only the first _CONTAINER_ITEM_LIMIT items.
            list_string += "...,\n"
            break
        list_string += f"{format_argument(item)},\n"
    return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
    tuple_string = f"Tuple[length={len(obj)}](\n"
    if not obj:
        return tuple_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only the first _CONTAINER_ITEM_LIMIT items.
            tuple_string += "...,\n"
            break
        tuple_string += f"{format_argument(item)},\n"
    return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
    dict_string = f"Dict[length={len(obj)}](\n"
    if not obj:
        return dict_string + "None)"
    for i, (key, value) in enumerate(obj.items()):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only the first _CONTAINER_ITEM_LIMIT items.
            dict_string += "...\n"
            break
        dict_string += f"{key}: {format_argument(value)},\n"
    return dict_string + ")"


@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    return f"Parameter({format_argument(obj.data)})"


@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
    return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`"  # type: ignore[arg-type]  # noqa: B950


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
    return f"`OnnxFunction({obj.name})`"


@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
    return f"`TracedOnnxFunction({obj.name})`"


# Adapted from torch/fx/graph.py to follow the torch shape-string format.
def _stringify_shape(shape: torch.Size | None) -> str:
    if shape is None:
        return ""
    return f"[{', '.join(str(x) for x in shape)}]"


rules = diagnostics.rules
levels = diagnostics.levels
RuntimeErrorWithDiagnostic = infra.RuntimeErrorWithDiagnostic
LazyString = formatter.LazyString
DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class Diagnostic(infra.Diagnostic):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)

    def log(self, level: int, message: str, *args, **kwargs) -> None:
        if self.logger.isEnabledFor(level):
            formatted_message = message % args
            if is_onnx_diagnostics_log_artifact_enabled():
                # Only log to the terminal if the artifact is enabled.
                # See [NOTE: `dynamo_export` diagnostics logging] for details.
                self.logger.log(level, formatted_message, **kwargs)

            self.additional_messages.append(formatted_message)


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    _bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
        init=False, default=Diagnostic
    )

    def __enter__(self):
        self._previous_log_level = self.logger.level
        # Adjust the logger level based on `options.verbosity_level` and the
        # environment variable `TORCH_LOGS`.
        # See [NOTE: `dynamo_export` diagnostics logging] for details.
        if not is_onnx_diagnostics_log_artifact_enabled():
            return super().__enter__()
        else:
            return self


diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostic_type=Diagnostic,
    format_argument=format_argument,
)


@dataclasses.dataclass
class UnsupportedFxNodeDiagnostic(Diagnostic):
    unsupported_fx_node: torch.fx.Node | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # NOTE: This is a hack to ensure that the additional fields are set and are
        # not None. Ideally they would not be declared as optional in the first
        # place, but that is a known limitation of `dataclasses` when inheriting
        # from a base class with default values. It is resolvable in Python 3.10
        # with `kw_only=True`.
        # https://stackoverflow.com/questions/69711886/python-dataclasses-inheritance-and-default-values
        if self.unsupported_fx_node is None:
            raise ValueError("unsupported_fx_node must be specified.")
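

# NOTE: Once Python >= 3.10 is required, the `__post_init__` check above could be
# replaced by declaring the field as required and keyword-only, roughly as in the
# sketch below (illustrative only, not the current implementation):
#
#   @dataclasses.dataclass(kw_only=True)
#   class UnsupportedFxNodeDiagnostic(Diagnostic):
#       unsupported_fx_node: torch.fx.Node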