# mypy: allow-untyped-defs
"""Exception types and error-message helpers for TorchDynamo.

Defines the exception hierarchy rooted at ``TorchDynamoException`` (restarts,
skips, graph breaks, user errors raised during tracing), the ``Observed*``
exceptions Dynamo uses to model exceptions raised *inside* traced user code,
and helpers that augment and format the error messages shown to users when
compilation fails.
"""
import os
import textwrap
from enum import auto, Enum
from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING

import torch._guards

from . import config
from .utils import counters


if TYPE_CHECKING:
    from torch._guards import CompileId


def exportdb_error_message(case_name: str) -> str:
    """Return a pointer to the exportdb documentation page for ``case_name``."""
    # exportdb page anchors use dashes where case names use underscores.
    return (
        "For more information about this error, see: "
        + "https://pytorch.org/docs/main/generated/exportdb/index.html#"
        + case_name.replace("_", "-")
    )


import logging


log = logging.getLogger(__name__)
# Artifact logger: verbose per-graph-break traces are routed here (enabled
# via TORCH_LOGS="graph_breaks").
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


class TorchDynamoException(RuntimeError):
    """Base class for all exceptions raised by TorchDynamo."""

    pass


class InternalTorchDynamoError(TorchDynamoException):
    """An unexpected error inside Dynamo itself (i.e. a Dynamo bug)."""

    pass


class RestartAnalysis(TorchDynamoException):
    """Request that Dynamo restart symbolic analysis of the current frame."""

    # Human-readable reason for the restart; may be None.
    restart_reason: str

    def __init__(self, *args, restart_reason=None) -> None:
        self.restart_reason = restart_reason
        super().__init__(*args)


class SpeculationRestartAnalysis(RestartAnalysis):
    # Restart flavor — presumably raised when a speculation fails; see callers.
    pass


class UnspecializeRestartAnalysis(RestartAnalysis):
    # Restart flavor — presumably raised to re-trace with an input
    # unspecialized; see callers.
    pass


class CompileCollectiveRestartAnalysis(RestartAnalysis):
    # Restart flavor — related to compiled collectives; see callers.
    pass


class SkipFrame(TorchDynamoException):
    """Raised to make Dynamo skip compiling the current frame entirely."""

    pass


class TorchRuntimeError(TorchDynamoException):
    """A runtime error raised while running torch code during tracing."""

    pass


class InvalidBackend(TorchDynamoException):
    """Raised when ``torch.compile`` is given an unknown backend name."""

    def __init__(self, name) -> None:
        super().__init__(
            f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
        )


class ResetRequired(TorchDynamoException):
    """Raised when the backend is changed without calling ``torch._dynamo.reset()``."""

    def __init__(self) -> None:
        super().__init__(
            textwrap.dedent(
                """
                Must call `torch._dynamo.reset()` before changing backends.  Detected two calls to
                `torch.compile()` with a different backend compiler arguments.
                """
            )
        )


class BackendCompilerFailed(TorchDynamoException):
    """Wraps an exception raised by the backend compiler.

    The original exception is preserved on ``self.inner_exception``.
    """

    def __init__(self, backend_fn, inner_exception) -> None:
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
        super().__init__(msg)


class Unsupported(TorchDynamoException):
    """Raised when Dynamo encounters code it cannot trace (a graph break).

    Each instance records the user-code stack at raise time and increments
    the global ``counters`` bucket for its category (default "unimplemented").
    """

    def __init__(self, msg, *, case_name=None) -> None:
        super().__init__(msg)
        # Stack of the *user* code being traced, used for error reporting.
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.msg = msg
        self.category: Optional[str] = None
        self.add_to_stats()
        # Optional exportdb example name (see exportdb_error_message).
        self.case_name: Optional[str] = case_name

    def remove_from_stats(self):
        """Decrement this message's counter, deleting the entry at zero."""
        assert self.category is not None
        counters[self.category][self.msg] -= 1
        if counters[self.category][self.msg] <= 0:
            del counters[self.category][self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Record this message under ``category`` in the global counters."""
        self.category = category
        counters[category][self.msg] += 1


class RecompileError(TorchDynamoException):
    """An error encountered during recompilation."""

    pass


class ArgsMismatchError(Unsupported):
    """Graph break due to a call-arguments mismatch."""

    def __init__(self, msg) -> None:
        super().__init__(msg)


class AttributeMutationError(Unsupported):
    """Graph break due to an unsupported attribute mutation."""

    def __init__(self, msg) -> None:
        super().__init__(msg)


class CondOpArgsMismatchError(ArgsMismatchError):
    """
    Internal error from cond() due to arguments mismatch.
    """

    def __init__(self, msg) -> None:
        super().__init__(msg)


class UserErrorType(Enum):
    # Categories of user-caused errors (see UserError).
    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()


class UserError(Unsupported):
    def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None:
        """
        Type of errors that would be valid in Eager, but not supported in TorchDynamo.
        The error message should tell user about next actions.

        error_type: Type of user error
        msg: Actionable error message
        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Append the exportdb link, separated from the message by a space
            # (if the message ends a sentence) or a newline otherwise.
            if msg.endswith("."):
                msg += " "
            else:
                msg += "\n"
            msg += exportdb_error_message(case_name)
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg


class SkipCodeRecursiveException(TorchDynamoException):
    # Presumably requests skipping this code object and its callees recursively;
    # see callers for the exact handling.
    pass


class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported):
    """Raised when the recompile cache limit is exceeded.

    Inherits both the skip-recursive behavior and Unsupported's stats tracking.
    """

    pass


class UnsafeScriptObjectError(TorchDynamoException):
    pass


class UncapturedHigherOrderOpError(TorchDynamoException):
    pass


class IncorrectUsage(Exception):
    """User misuse of a Dynamo API (note: not a TorchDynamoException)."""

    pass


class ObservedException(TorchDynamoException):
    # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
    pass


class ObservedUserStopIteration(ObservedException):
    # A user StopIteration exception observed during the Dynamo tracing (e.g. Dynamo tracing __next__)

    # The StopIteration payload (first positional arg), or None.
    value: Optional[Any]

    # Reference `StopIteration_init` in CPython
    # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
    def __init__(self, *args, **kwargs) -> None:
        super().__init__("unhandled `raise StopIteration`")
        if len(args) > 0:
            self.value = args[0]
        else:
            self.value = None


class ObservedKeyError(ObservedException):
    # A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__
    pass


class ObservedAttributeError(ObservedException):
    # An AttributeError exception to be raised from inside Dynamo tracing.
    # This can happen on user defined object __getattr__
    pass


# Maps builtin exception types (observed in traced user code) to the
# Observed* wrapper Dynamo raises internally in their place.
observed_exception_map = {
    StopIteration: ObservedUserStopIteration,
    KeyError: ObservedKeyError,
    AttributeError: ObservedAttributeError,
}


def raise_observed_exception(e, tx, vt):
    """Simulate user code raising exception type ``e`` during tracing.

    Builds an exception VariableTracker, pushes it onto ``tx``'s exception
    stack, and raises the corresponding Observed* exception.
    Raises KeyError if ``e`` is not in ``observed_exception_map``.
    """
    from .variables import BuiltinVariable

    # CPython here raises an exception. Since there is no python code, we have to manually setup the exception
    # stack and raise the exception.
    exception_vt = BuiltinVariable(e).call_function(vt, [], {})
    tx.exn_vt_stack.append(exception_vt)
    raise observed_exception_map[e]


def handle_observed_exception(tx):
    # This is essentially exception handling code, equivalent of this pseudo code
    #
    # try:
    #     ... somebody raising StopIteration
    # except StopIteration
    #     pass
    #
    # If this was going through the python code, we would have called exception_handler method, but FOR_ITER
    # handles the exception completely in CPython. For example for 3.11, the resulting bytecode is
    #
    #
    #   6          46 LOAD_GLOBAL              2 (StopIteration)
    #              58 RAISE_VARARGS            1
    #         >>   60 PUSH_EXC_INFO

    #   7          62 LOAD_GLOBAL              2 (StopIteration)
    #              74 CHECK_EXC_MATCH
    #              76 POP_JUMP_FORWARD_IF_FALSE     3 (to 84)
    #              78 POP_TOP

    #   8          80 POP_EXCEPT
    #

    # Fortunately this translates to a simple pop from the exn_vt_stack
    tx.exn_vt_stack.pop()


# These exceptions are ok to fallback to eager/graph_break.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)


def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
    # This function calls unimplemented internally and eventually graph breaks
    # or falls to eager. unimplemented itself does not print any user warnings,
    # i.e., its very silent. This helper function is intended when an error is
    # encountered in the torch.compile stack which is worth showing as warning
    # to the user. For example, if AOT Autograd backend fails with a fake tensor
    # exception, its ok to fallback to eager but not silently. Here, we can use
    # this function to log the message and the stack trace.
    graph_break_msg = format_error_msg_verbose(e, code)
    graph_breaks_log.debug("%s", graph_break_msg)
    log.warning(msg)
    unimplemented(msg, from_exc=e)


# Sentinel for "no originating exception" in unimplemented(); distinguishes
# an omitted from_exc from an explicit from_exc=None.
_NOTHING = object()


def unimplemented(
    msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
    """Raise Unsupported(msg), chained from ``from_exc`` when one is given."""
    # Debugging aid: set the env var BREAK to an exact graph-break message to
    # make this assert fail right where that graph break is raised.
    assert msg != os.environ.get("BREAK", False)
    if from_exc is not _NOTHING:
        raise Unsupported(msg, case_name=case_name) from from_exc
    raise Unsupported(msg, case_name=case_name)


def warning(msg: str) -> None:
    """Count a warning message (same BREAK env-var debugging aid as above)."""
    counters["warnings"][msg] += 1
    assert msg != os.environ.get("BREAK", False)


# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
    def __init__(self, value) -> None:
        self.value = value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return self.__str__()


def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append user-facing context (user stack, replay/minifier hints,
    suppression instructions) to ``exc``'s message in place.

    Also sets ``exc.innermost_user_frame_summary`` (possibly None), which
    get_exc_message() reads later.
    """
    import traceback

    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    real_stack = get_real_stack(exc)
    if real_stack is not None and len(real_stack) > 0:
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'

    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        if hasattr(exc.inner_exception, "buck_command"):
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {exc.inner_exception.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    if not config.suppress_errors and not export:
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            "    import torch._dynamo\n"
            "    torch._dynamo.config.suppress_errors = True\n"
        )

    old_msg = "" if len(exc.args) == 0 else str(exc.args[0])

    if isinstance(exc, KeyError):
        # KeyError repr()s its single arg, so wrap to preserve the plain text.
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]


def get_exc_message(
    e: Exception, compile_id: "CompileId"
) -> Tuple[Optional[str], Optional[int]]:
    """Return (filename, lineno) of the innermost user frame, tagging ``e``
    with ``compile_id``.

    NOTE: reads ``e.innermost_user_frame_summary``, which augment_exc_message()
    must have set beforehand.
    """
    filename = None
    lineno = None
    if e.innermost_user_frame_summary is not None:  # type: ignore[attr-defined]
        filename = e.innermost_user_frame_summary.filename  # type: ignore[attr-defined]
        lineno = e.innermost_user_frame_summary.lineno  # type: ignore[attr-defined]
    e.compile_id = compile_id  # type: ignore[attr-defined]
    return filename, lineno


def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
    """Return the user-code stack attached to ``exc``, optionally prefixed by
    the (dynamo-filtered) stack above the current frame; None if unavailable.
    """
    real_stack = getattr(exc, "real_stack", None)
    if real_stack is None:
        return None

    # NB: it's possible for real_stack to be []; we still attempt to
    # report a stack anyway because the stack_above_dynamo may still
    # be useful for debugging

    stack_above_dynamo = []
    if frame is not None:
        # NB: frame is PyInterpreterFrame on Python 3.11 and later,
        # not a TRUE frame object. You can't actually feed it
        # to traceback because it doesn't have enough information.
        # To solve this problem, we technically should just materialize
        # the frame, the same way _PyFrame_GetFrameObject would do
        # (but we cannot actually do this, because this populates
        # frame_obj field, which default eval frame doesn't like).
        #
        # Fortunately, in this case, we can hack it: there's no need
        # to actually use the truly top frame, we can just extract
        # from where we are right now and rely on filter_stack to
        # get rid of all the dynamo frames. For ease of testing
        # we apply this behavior to ALL Python versions
        stack_above_dynamo = filter_stack(extract_stack())

    return cast(StackSummary, stack_above_dynamo + real_stack)


# filter out all frames after entering dynamo
def filter_stack(stack):
    """Return the prefix of ``stack`` that precedes Dynamo's own frames."""
    user_stack = []
    for frame in stack:
        # Everything from convert_frame onward is Dynamo internals.
        if "convert_frame" in frame.filename:
            break
        # Skip the eval_frame shim and the torch._dynamo.optimize call site.
        if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line:
            continue
        user_stack.append(frame)

    return user_stack


def format_error_msg_verbose(
    exc: Exception, code, record_filename=None, frame=None
) -> str:
    """Build the verbose "WON'T CONVERT" report: full traceback plus the
    user code that was being processed (record_filename is unused here).
    """
    msg = (
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n"
    )
    msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
    msg += format_exc()
    real_stack = get_real_stack(exc, frame)
    if real_stack is not None:
        msg += (
            "\n"
            + "=" * 10
            + "  The above exception occurred while processing the following code  "
            + "=" * 10
            + "\n\n"
        )
        msg += "".join(format_list(real_stack))
        msg += "\n"
        msg += "=" * 10

    return msg


def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
    """Build the "WON'T CONVERT" message; verbose form when config.verbose."""
    msg = os.linesep * 2

    if config.verbose:
        msg = format_error_msg_verbose(exc, code, record_filename, frame)
    else:
        msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc()}"

    return msg