xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import traceback as tb
3from typing import Any, Dict, Tuple
4
5
# An exception paired with the stack summary extracted from its traceback.
# This is the shape produced by ``_wrap_exception`` and recognized by
# ``_is_wrapped_exception``.
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]

# Only ``CheckpointException`` is exported by a star import; the
# underscore-prefixed helpers in this module are private.
__all__ = ["CheckpointException"]
9
10
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """Pair ``exc`` with a stack summary extracted from its traceback.

    Args:
        exc: The exception to wrap.

    Returns:
        A 2-tuple ``(exc, stack)`` where ``stack`` is the result of
        ``traceback.extract_tb`` applied to ``exc.__traceback__``.
    """
    stack = tb.extract_tb(exc.__traceback__)
    return exc, stack
13
14
def _is_wrapped_exception(obj: Any) -> bool:
    """Report whether ``obj`` has the ``(exception, stack summary)`` shape.

    Args:
        obj: Any object.

    Returns:
        ``True`` only when ``obj`` is a tuple of exactly two items whose first
        item is a ``BaseException`` instance and whose second item is a
        ``traceback.StackSummary`` instance; ``False`` otherwise.
    """
    if isinstance(obj, tuple) and len(obj) == 2:
        error, stack = obj
        return isinstance(error, BaseException) and isinstance(stack, tb.StackSummary)
    return False
21
22
class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save.

    Args:
        msg: Message describing the failure.
        failures: Dictionary mapping ranks to the ``(exception, stack summary)``
            pair associated with each of them.
    """

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
        # Hand both values to BaseException so they are retained in ``self.args``.
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
        return self._failures

    def __str__(self) -> str:
        """Render one traceback-style section for every entry in ``failures``.

        The result starts with a header line listing the ranks, followed, for
        each rank, by a ``Traceback`` banner, the formatted stack entries (when
        a stack summary is present) and the formatted exception itself.
        """
        # Collect the pieces and join once; this also avoids binding a local
        # named ``str``, which would shadow the builtin.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            # The stack summary may be absent; then only the exception line is shown.
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)
44