xref: /aosp_15_r20/external/pytorch/torch/_dynamo/exc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3import textwrap
4from enum import auto, Enum
5from traceback import extract_stack, format_exc, format_list, StackSummary
6from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING
7
8import torch._guards
9
10from . import config
11from .utils import counters
12
13
14if TYPE_CHECKING:
15    from torch._guards import CompileId
16
17
def exportdb_error_message(case_name):
    """Return a one-line pointer to the exportdb documentation entry for ``case_name``.

    The anchor appended to the URL is ``case_name`` with every underscore
    replaced by a hyphen.
    """
    exportdb_url = "https://pytorch.org/docs/main/generated/exportdb/index.html#"
    anchor = case_name.replace("_", "-")
    return "For more information about this error, see: " + exportdb_url + anchor
24
25
26import logging
27
28
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Artifact logger obtained for the "graph_breaks" artifact; used by
# unimplemented_with_warning to emit the verbose graph-break message at debug level.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
31
32
class TorchDynamoException(RuntimeError):
    """Common base class for the TorchDynamo exception types defined in this module."""
35
36
class InternalTorchDynamoError(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own.

    NOTE(review): presumably marks failures internal to Dynamo; the raising
    sites are not visible in this module -- confirm against callers.
    """
39
40
class RestartAnalysis(TorchDynamoException):
    """TorchDynamoException that carries an optional ``restart_reason``.

    Args:
        *args: Forwarded unchanged to the base exception constructor.
        restart_reason: Keyword-only reason stored on the instance; ``None``
            when the raiser does not supply one.
    """

    # The constructor defaults this to None, so the attribute is Optional.
    restart_reason: Optional[str]

    def __init__(self, *args, restart_reason: Optional[str] = None) -> None:
        self.restart_reason = restart_reason
        super().__init__(*args)
47
48
class SpeculationRestartAnalysis(RestartAnalysis):
    """RestartAnalysis subtype; defines no additional state or behavior here."""
51
52
class UnspecializeRestartAnalysis(RestartAnalysis):
    """RestartAnalysis subtype; defines no additional state or behavior here."""
55
56
class CompileCollectiveRestartAnalysis(RestartAnalysis):
    """RestartAnalysis subtype; defines no additional state or behavior here."""
59
60
class SkipFrame(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own.

    NOTE(review): by its name, presumably raised to skip compiling the current
    frame; handling code is not visible in this module -- confirm.
    """
63
64
class TorchRuntimeError(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own."""
67
68
class InvalidBackend(TorchDynamoException):
    """Error whose message names an unrecognized backend.

    Args:
        name: The rejected backend name; rendered with ``repr`` in the message.
    """

    def __init__(self, name) -> None:
        message = f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
        super().__init__(message)
74
75
class ResetRequired(TorchDynamoException):
    """Error with a fixed message telling the user to call ``torch._dynamo.reset()``."""

    def __init__(self) -> None:
        # dedent strips the common indentation of the literal below, so the
        # message text does not depend on how deeply this code is nested.
        message = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends.  Detected two calls to
            `torch.compile()` with a different backend compiler arguments.
            """
        )
        super().__init__(message)
86
87
class BackendCompilerFailed(TorchDynamoException):
    """Wraps an exception raised by a backend compiler function.

    Attributes set on the instance:
        backend_name: ``backend_fn.__name__``, or ``"?"`` when it has none.
        inner_exception: The exception object passed by the caller.
    """

    def __init__(self, backend_fn, inner_exception) -> None:
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        inner_type_name = type(inner_exception).__name__
        super().__init__(
            f"backend={self.backend_name!r} raised:\n{inner_type_name}: {inner_exception}"
        )
94
95
class Unsupported(TorchDynamoException):
    """Exception that records itself in ``counters`` when constructed.

    Attributes set on the instance:
        real_stack: Result of ``torch._guards.TracingContext.extract_stack()``
            taken at construction time.
        msg: The message passed by the caller (also the counter key).
        category: Name of the ``counters`` bucket this instance was counted in.
        case_name: Optional case name supplied by the caller.
    """

    def __init__(self, msg, *, case_name=None) -> None:
        super().__init__(msg)
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.msg = msg
        self.case_name: Optional[str] = case_name
        self.category: Optional[str] = None
        # Counts this instance under the default "unimplemented" category.
        self.add_to_stats()

    def remove_from_stats(self):
        """Undo one ``add_to_stats`` call; drop the entry once its count is <= 0."""
        assert self.category is not None
        # assumes counters[category] returns the same mapping object on each
        # lookup -- TODO confirm against the definition of `counters`
        bucket = counters[self.category]
        bucket[self.msg] -= 1
        if bucket[self.msg] <= 0:
            del bucket[self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Remember ``category`` and increment its counter for this message."""
        self.category = category
        counters[category][self.msg] += 1
114
115
class RecompileError(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own."""
118
119
class ArgsMismatchError(Unsupported):
    """Unsupported subtype constructed from a message only.

    The constructor accepts just ``msg``; the ``case_name`` keyword of the
    base class is not exposed.
    """

    def __init__(self, msg) -> None:
        super().__init__(msg)
123
124
class AttributeMutationError(Unsupported):
    """Unsupported subtype constructed from a message only.

    The constructor accepts just ``msg``; the ``case_name`` keyword of the
    base class is not exposed.
    """

    def __init__(self, msg) -> None:
        super().__init__(msg)
128
129
class CondOpArgsMismatchError(ArgsMismatchError):
    """Internal error from cond() caused by an arguments mismatch.

    Constructed from a message only, like its base class.
    """

    def __init__(self, msg) -> None:
        super().__init__(msg)
137
138
class UserErrorType(Enum):
    """Categories of user error; passed as ``error_type`` when building a UserError.

    Values are assigned by ``auto()`` in declaration order, so reordering the
    members changes their numeric values.
    """

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
147
148
class UserError(Unsupported):
    def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None:
        """
        Errors for code that is valid in Eager but not supported in TorchDynamo.
        The message should tell the user what to do next.

        error_type: Type of user error
        msg: Actionable error message
        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
            When given, a link to the exportdb entry is appended to ``msg``.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Stay on the same line after a finished sentence, otherwise
            # start the exportdb pointer on a new line.
            separator = " " if msg.endswith(".") else "\n"
            msg = msg + separator + exportdb_error_message(case_name)
        # Only the (possibly extended) message is forwarded to the base class.
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg
169
170
class SkipCodeRecursiveException(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own."""
173
174
class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported):
    """Inherits from both SkipCodeRecursiveException and Unsupported; adds nothing itself."""
177
178
class UnsafeScriptObjectError(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own."""
181
182
class UncapturedHigherOrderOpError(TorchDynamoException):
    """TorchDynamoException subtype that adds no behavior of its own."""
185
186
class IncorrectUsage(Exception):
    """Plain ``Exception`` subtype; unlike its neighbours it is not a TorchDynamoException."""
189
190
class ObservedException(TorchDynamoException):
    """An exception observed during tracing; used by Dynamo to handle exceptions."""
194
195
class ObservedUserStopIteration(ObservedException):
    """A user ``StopIteration`` observed during Dynamo tracing (e.g. tracing ``__next__``).

    The first positional argument, if any, is kept as ``value``; remaining
    positional arguments and all keyword arguments are ignored.
    """

    value: Optional[Any]

    # Reference `StopIteration_init` in CPython
    # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
    def __init__(self, *args, **kwargs) -> None:
        super().__init__("unhandled `raise StopIteration`")
        self.value = args[0] if args else None
208
209
class ObservedKeyError(ObservedException):
    """A ``KeyError`` raised from inside Dynamo tracing, e.g. on dict ``__getitem__``."""
213
214
class ObservedAttributeError(ObservedException):
    """An ``AttributeError`` raised from inside Dynamo tracing.

    This can happen on a user-defined object's ``__getattr__``.
    """
218
219
# Maps a builtin exception type to the Observed* class that
# raise_observed_exception raises in its place. Types missing from this
# mapping make raise_observed_exception fail with a KeyError.
observed_exception_map = {
    StopIteration: ObservedUserStopIteration,
    KeyError: ObservedKeyError,
    AttributeError: ObservedAttributeError,
}
225
226
def raise_observed_exception(e, tx, vt):
    """Push an exception variable onto ``tx.exn_vt_stack`` and raise its Observed* class.

    Args:
        e: Builtin exception type; must be a key of ``observed_exception_map``.
        tx: Object whose ``exn_vt_stack`` list receives the exception variable.
        vt: Passed as the first argument of ``BuiltinVariable(e).call_function``.

    Raises:
        The class that ``observed_exception_map`` associates with ``e``.
    """
    from .variables import BuiltinVariable

    # CPython here raises an exception. Since there is no python code, we have
    # to manually set up the exception stack and raise the exception.
    # NOTE(review): `vt`, not `tx`, is the first argument of call_function --
    # confirm against BuiltinVariable.call_function's signature.
    observed_vt = BuiltinVariable(e).call_function(vt, [], {})
    tx.exn_vt_stack.append(observed_vt)
    observed_cls = observed_exception_map[e]
    raise observed_cls
235
236
def handle_observed_exception(tx):
    """Discard the most recent entry of ``tx.exn_vt_stack``.

    This is the exception-handling counterpart of this pseudo code::

        try:
            ... somebody raising StopIteration
        except StopIteration:
            pass
    """
    # If this was going through the python code, we would have called the
    # exception_handler method, but FOR_ITER handles the exception completely
    # in CPython. For example for 3.11, the resulting bytecode is
    #
    #   6          46 LOAD_GLOBAL              2 (StopIteration)
    #              58 RAISE_VARARGS            1
    #         >>   60 PUSH_EXC_INFO
    #
    #   7          62 LOAD_GLOBAL              2 (StopIteration)
    #              74 CHECK_EXC_MATCH
    #              76 POP_JUMP_FORWARD_IF_FALSE     3 (to 84)
    #              78 POP_TOP
    #
    #   8          80 POP_EXCEPT
    #
    # Fortunately this translates to a simple pop from the exn_vt_stack.
    tx.exn_vt_stack.pop()
263
264
# Fake-tensor exception types for which it is OK to fall back to eager or
# graph break. Kept as a tuple of classes.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
272
273
def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
    """Log ``msg`` as a warning, then raise via ``unimplemented`` chained to ``e``.

    ``unimplemented`` itself does not print any user warning, i.e. it is very
    silent. This helper is for errors in the torch.compile stack that are worth
    showing to the user. For example, if the AOT Autograd backend fails with a
    fake tensor exception, it is OK to fall back to eager, but not silently.

    Args:
        e: The exception being reported; becomes the cause of the raised error.
        code: Code object passed to ``format_error_msg_verbose``.
        msg: Text used for both the warning and the raised error.
    """
    # The verbose message (with stack trace) goes to the graph_breaks artifact
    # log at debug level only.
    graph_breaks_log.debug("%s", format_error_msg_verbose(e, code))
    log.warning(msg)
    unimplemented(msg, from_exc=e)
286
287
# Sentinel default for `from_exc` in unimplemented(): tells "argument not
# supplied" apart from an explicit None (which would mean `raise ... from None`).
_NOTHING = object()
289
290
def unimplemented(
    msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
    """Raise ``Unsupported(msg, case_name=case_name)``; never returns.

    Args:
        msg: Message for the raised exception.
        from_exc: When supplied, the raised exception is chained with
            ``raise ... from from_exc``.
        case_name: Forwarded to ``Unsupported``.
    """
    # Debugging hook: fails the assertion when msg equals the BREAK
    # environment variable.
    assert msg != os.environ.get("BREAK", False)
    error = Unsupported(msg, case_name=case_name)
    if from_exc is _NOTHING:
        raise error
    raise error from from_exc
298
299
def warning(msg: str) -> None:
    """Increment the ``counters["warnings"]`` entry for ``msg``."""
    warning_counts = counters["warnings"]
    warning_counts[msg] += 1
    # Debugging hook: fails the assertion (after counting) when msg equals
    # the BREAK environment variable.
    assert msg != os.environ.get("BREAK", False)
303
304
# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
    """Wrapper whose ``str`` and ``repr`` are both the plain ``str`` of ``value``.

    augment_exc_message stores an instance of this as the first argument of a
    KeyError instead of a bare string.
    """

    def __init__(self, value) -> None:
        self.value = value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        # Deliberately the same text as __str__ (no quoting or escaping).
        return str(self)
316
317
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append debugging hints to the first argument of ``exc``, in place.

    Also sets ``exc.innermost_user_frame_summary`` to the last frame of the
    user stack returned by ``get_real_stack``, or ``None`` when there is none.

    Args:
        exc: Exception to modify; its ``args`` tuple is replaced.
        msg: Initial text of the suffix that gets appended.
        export: When true, the "suppress_errors" hint is omitted.
    """
    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    user_stack = get_real_stack(exc)
    if user_stack is not None and len(user_stack) > 0:
        exc.innermost_user_frame_summary = user_stack[-1]  # type: ignore[attr-defined]
        rendered_stack = "".join(format_list(user_stack))
        msg += f"\nfrom user code:\n {rendered_stack}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        record_path = exc.record_filename
        msg += (
            f"\nLast frame execution written to {record_path}. "
            "To run only this frame while debugging, run"
            f" torch._dynamo.replay('{record_path}').\n"
        )

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'

    # Minifier hint: only when the wrapped exception advertises a script path.
    inner = getattr(exc, "inner_exception", None)
    if hasattr(inner, "minifier_path"):
        minifier_hint = f"\nMinifier script written to {inner.minifier_path}. Run "
        if hasattr(inner, "buck_command"):
            minifier_hint += (
                "this buck command to find the smallest traced graph "
                f"which reproduces this error: {inner.buck_command}\n"
            )
        else:
            minifier_hint += "this script to find the smallest traced graph which reproduces this error.\n"
        msg += minifier_hint

    if not (config.suppress_errors or export):
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            "    import torch._dynamo\n"
            "    torch._dynamo.config.suppress_errors = True\n"
        )

    previous_text = str(exc.args[0]) if exc.args else ""
    combined = previous_text + msg
    # KeyError gets a KeyErrorMsg wrapper instead of a bare string as its
    # first argument; every other exception type gets the string itself.
    first_arg = KeyErrorMsg(combined) if isinstance(exc, KeyError) else combined
    exc.args = (first_arg,) + exc.args[1:]
365
366
def get_exc_message(
    e: Exception, compile_id: "CompileId"
) -> Tuple[Optional[str], Optional[int]]:
    """Tag ``e`` with ``compile_id`` and return its innermost user frame location.

    Args:
        e: Exception to inspect; ``e.compile_id`` is set as a side effect.
        compile_id: Value stored on ``e.compile_id``.

    Returns:
        ``(filename, lineno)`` of ``e.innermost_user_frame_summary``, or
        ``(None, None)`` when that attribute is missing or ``None``.
    """
    # augment_exc_message is what sets innermost_user_frame_summary. An
    # exception that never went through it lacks the attribute, so read it
    # defensively rather than masking the real error with an AttributeError.
    frame_summary = getattr(e, "innermost_user_frame_summary", None)
    e.compile_id = compile_id  # type: ignore[attr-defined]
    if frame_summary is None:
        return None, None
    return frame_summary.filename, frame_summary.lineno
377
378
def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
    """Return the user-code stack recorded on ``exc``, or ``None`` if it has none.

    Args:
        exc: Exception that may carry a ``real_stack`` attribute.
        frame: When not ``None``, the filtered stack of the current call site
            is prepended. The frame object itself is never inspected.

    Returns:
        ``None`` when ``exc.real_stack`` is missing or ``None``; otherwise the
        (optionally prefixed) stack.
    """
    user_frames = getattr(exc, "real_stack", None)
    if user_frames is None:
        return None

    # NB: it's possible for real_stack to be []; we still attempt to
    # report a stack anyway because the stack above dynamo may still
    # be useful for debugging.

    # NB: frame is PyInterpreterFrame on Python 3.11 and later,
    # not a TRUE frame object.  You can't actually feed it
    # to traceback because it doesn't have enough information.
    # To solve this problem, we technically should just materialize
    # the frame, the same way _PyFrame_GetFrameObject would do
    # (but we cannot actually do this, because this populates
    # frame_obj field, which default eval frame doesn't like).
    #
    # Fortunately, in this case, we can hack it: there's no need
    # to actually use the truly top frame, we can just extract
    # from where we are right now and rely on filter_stack to
    # get rid of all the dynamo frames.  For ease of testing
    # we apply this behavior to ALL Python versions.
    frames_above_dynamo = filter_stack(extract_stack()) if frame is not None else []

    return cast(StackSummary, frames_above_dynamo + user_frames)
406
407
# filter out all frames after entering dynamo
def filter_stack(stack):
    """Return the frames of ``stack`` that precede the entry into dynamo.

    Iteration stops at the first frame whose filename contains
    ``"convert_frame"``. Before that point, frames whose filename contains
    ``"eval_frame"`` or whose source line contains
    ``"torch._dynamo.optimize("`` are dropped.

    Args:
        stack: Iterable of frame summaries exposing ``filename`` and ``line``.

    Returns:
        A new list with the frames that were kept, in their original order.
    """
    user_stack = []
    for frame in stack:
        if "convert_frame" in frame.filename:
            break
        # `line` is None when the source text is unavailable; treat that as
        # an empty line instead of failing on `in None`.
        if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in (
            frame.line or ""
        ):
            continue
        user_stack.append(frame)

    return user_stack
419
420
def format_error_msg_verbose(
    exc: Exception, code, record_filename=None, frame=None
) -> str:
    """Build the verbose "WON'T CONVERT" report for ``code``.

    The report holds a header naming the code object, the currently handled
    exception's traceback (``format_exc()``), and, when ``get_real_stack``
    returns a stack for ``exc``, the user code being processed.

    Args:
        exc: Exception whose recorded user stack is appended, if any.
        code: Code object; ``co_name``, ``co_filename`` and
            ``co_firstlineno`` are read.
        record_filename: Accepted but not used by this function.
        frame: Forwarded to ``get_real_stack``.
    """
    rule = "=" * 10
    sections = [
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n",
        f"{rule} TorchDynamo Stack Trace {rule}\n",
        format_exc(),
    ]

    user_stack = get_real_stack(exc, frame)
    if user_stack is not None:
        sections.append(
            f"\n{rule} The above exception occurred while processing the following code {rule}\n\n"
        )
        sections.append("".join(format_list(user_stack)))
        sections.append("\n")
        sections.append(rule)

    return "".join(sections)
443
444
def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
    """Return the "WON'T CONVERT" message for ``code``.

    With ``config.verbose`` set this delegates to ``format_error_msg_verbose``.
    Otherwise it returns a short header followed by the currently handled
    exception's traceback (``format_exc()``).

    Args:
        exc: Exception being reported; only used by the verbose variant.
        code: Code object; ``co_name``, ``co_filename`` and
            ``co_firstlineno`` are read.
        record_filename: Forwarded to the verbose variant.
        frame: Forwarded to the verbose variant.
    """
    # The previous initial `msg = os.linesep * 2` was overwritten on both
    # branches and never reached the caller, so it has been dropped.
    if config.verbose:
        return format_error_msg_verbose(exc, code, record_filename, frame)

    return (
        f"WON'T CONVERT {code.co_name} {code.co_filename}"
        f" line {code.co_firstlineno} \ndue to: \n{format_exc()}"
    )
455