# mypy: ignore-errors

import torch
from torch.utils._pytree import tree_map
from typing import Iterator, List, Optional
import logging
import contextlib
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks

logger = logging.getLogger("LoggingTensor")

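# Compact dtype tags used when rendering result types in captured log lines,
# e.g. a float32 tensor of shape (2, 3) is shown as "f32[2, 3]".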
_dtype_abbrs = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor, torch function is disabled, so we bypass it entirely
# 3. Enter the dispatcher and wind your way through Autograd
# 4. Hit the Python dispatch key and call __torch_dispatch__

# This Tensor can work with autograd in two ways:
#  - The wrapped Tensor does not require gradients. In that case, the LoggingTensor
#    can require gradients if the user asks for it as a constructor kwarg.
#  - The wrapped Tensor can require gradients. In that case, autograd will be tracked
#    for the wrapped Tensor, and the LoggingTensor itself cannot require gradients.
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
# test or you might get surprising behavior.
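# A minimal sketch of the two setups (illustrative only, not run at import time):
#   # wrapper requires grad, wrapped tensor does not
#   a = LoggingTensor(torch.randn(2), requires_grad=True)
#   # wrapped tensor requires grad, wrapper does not
#   b = LoggingTensor(torch.randn(2, requires_grad=True))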

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    context = contextlib.nullcontext

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) shouldn't hold any storage of its
        # own for the wrapped data, but it should still advertise the same
        # device as the tensor it wraps
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            # Strip the LoggingTensor wrapper so the real op runs on the inner tensor
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            # Re-wrap tensor outputs so results stay LoggingTensors
            return cls(e) if isinstance(e, torch.Tensor) else e

        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs

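# Illustrative sketch (not run at import time): an op dispatched through a
# LoggingTensor and captured with the capture_logs helper below renders roughly as
#   $1: f32[2] = torch._ops.aten.add.Tensor($0, $0)
# for `torch.add(x, x)` where x = LoggingTensor(torch.ones(2)); the exact
# formatting comes from LoggingTensorHandler below.
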
class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs

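# Illustrative sketch of LoggingTensorMode: while the mode is active it logs every
# dispatcher call, including ops on plain tensors, so no subclass wrapping is needed:
#   x = torch.ones(2)
#   with LoggingTensorMode():
#       x + 1   # logs an aten add call
# Pair it with capture_logs(is_mode=True) below to collect the rendered lines.
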
class LoggingTensorReentrant(LoggingTensor):
    context = torch.overrides.enable_reentrant_dispatch

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    def __init__(
            self, log_list: List[str], use_shortid_for_all_tensors: bool,
            with_type: bool, tracebacks_list: Optional[List]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
        self.tracebacks_list = tracebacks_list
        self.memo = WeakTensorKeyDictionary()
        self.next_id = 0
        self.with_type = with_type

    def _shortid(self, t: torch.Tensor) -> int:
        # Assign each tensor a stable short id (rendered as $0, $1, ...) on first sight
        if t not in self.memo:
            self.memo[t] = self.next_id
            self.next_id += 1
        return self.memo[t]

    def _fmt(self, a: object, with_type: bool = False) -> str:
        cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
        if isinstance(a, cond_cls):
            maybe_type = ""
            if with_type and self.with_type:
                maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
            x = f"${self._shortid(a)}{maybe_type}"
            return x
        else:
            return repr(a)

    def emit(self, record):
        # Render the record as "<rets> = <msg>(<args>, <kwargs>)", using short ids for tensors
        fmt_args = ", ".join(
            itertools.chain(
                (str(tree_map(self._fmt, a)) for a in record.args[0]),
                (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
            )
        )
        fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
        if self.tracebacks_list is not None:
            self.tracebacks_list.append(record.traceback)

def log_input(name: str, var: object) -> None:
    logger.info("input", (name,), {}, var)  # noqa: PLE1205

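# Illustrative sketch: under capture_logs below, log_input("x", t) for a float32
# tensor t of shape (2,) renders roughly as
#   $0: f32[2] = input('x')
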
class GatherTraceback(logging.Filter):
    def __init__(self, python=True, script=True, cpp=False):
        self.python = python
        self.script = script
        self.cpp = cpp

    def filter(self, record):
        record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
        return True

@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,
        use_shortid_for_all_tensors=is_mode,
        tracebacks_list=tracebacks_list if collect_traceback else None
    )
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if collect_traceback:
        logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb))
    try:
        if collect_traceback:
            yield log_list, tracebacks_list
        else:
            yield log_list
    finally:
        # Replace the raw tracebacks with symbolized ones in place so callers see readable frames
        symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
        tracebacks_list.clear()
        tracebacks_list.extend(symbolized_tracebacks)
        logger.removeHandler(handler)

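# Illustrative usage sketch of capture_logs (the names x, logs, tbs are examples,
# not part of this module); with traceback collection enabled it yields a
# (logs, tracebacks) pair:
#   x = LoggingTensor(torch.ones(2))
#   with capture_logs(python_tb=True) as (logs, tbs):
#       log_input("x", x)
#       x + x
#   # logs holds the rendered strings; tbs holds symbolized tracebacks, one per entry.
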
@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
    with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
        yield logs
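
# Illustrative usage sketch of capture_logs_with_logging_tensor_mode: no tensor
# subclass is needed, since LoggingTensorMode intercepts ops on plain tensors:
#   x = torch.randn(2)
#   with capture_logs_with_logging_tensor_mode() as logs:
#       x + 1
#   # logs contains one rendered line per aten op dispatched inside the block.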