# mypy: ignore-errors

import torch
from torch.utils._pytree import tree_map
from typing import Iterator, List, Optional
import logging
import contextlib
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks

logger = logging.getLogger("LoggingTensor")

_dtype_abbrs = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled, so we bypass it entirely
# 3. Enter the dispatcher and wind your way through Autograd
# 4. Hit the Python dispatch key and call __torch_dispatch__

# This Tensor can work with autograd in two ways:
# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor
#   can require gradients if the user asks for it as a constructor kwarg.
# - The wrapped Tensor can require gradients. In that case autograd will be tracked
#   for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
# test or you might get surprising behavior.

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    context = contextlib.nullcontext

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        # Run the op on the unwrapped tensors, rewrap the outputs, then log the call
        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs
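
# Example (illustrative sketch, assuming the capture_logs() helper defined
# further down; exact log text can vary across PyTorch versions): every ATen
# call on a LoggingTensor funnels through __torch_dispatch__ above and is
# recorded by the capture handler:
#
#   with capture_logs() as logs:
#       x = LoggingTensor(torch.ones(2))
#       log_input("x", x)
#       y = x.sin()
#   # logs now holds lines along the lines of:
#   #   $0: f32[2] = input('x')
#   #   $1: f32[2] = torch._ops.aten.sin.default($0)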

# Logs every op dispatched while the mode is active, operating directly on
# plain tensors (no wrapper subclass involved).
class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
        return rs

# Same as LoggingTensor, but the underlying op runs with reentrant dispatch enabled.
class LoggingTensorReentrant(LoggingTensor):
    context = torch.overrides.enable_reentrant_dispatch
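
# Example (illustrative sketch; the exact formatting of factory-function
# arguments in the log is version-dependent): the mode intercepts ops on plain
# tensors, including factory functions, for everything run inside the `with`
# block. This is why capture_logs(is_mode=True) assigns short ids to all
# torch.Tensors, not just LoggingTensors:
#
#   with capture_logs(is_mode=True) as logs, LoggingTensorMode():
#       torch.ones(2) + torch.ones(2)
#   # logs holds something like:
#   #   $0: f32[2] = torch._ops.aten.ones.default([2], device=device(type='cpu'), pin_memory=False)
#   #   $1: f32[2] = torch._ops.aten.ones.default([2], device=device(type='cpu'), pin_memory=False)
#   #   $2: f32[2] = torch._ops.aten.add.Tensor($0, $1)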

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    def __init__(
            self, log_list: List[str], use_shortid_for_all_tensors: bool,
            with_type: bool, tracebacks_list: Optional[List]) -> None:
        super().__init__()
        self.log_list = log_list
        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
        self.tracebacks_list = tracebacks_list
        self.memo = WeakTensorKeyDictionary()
        self.next_id = 0
        self.with_type = with_type

    def _shortid(self, t: torch.Tensor) -> int:
        # Assign each distinct tensor a stable short id ($0, $1, ...); the
        # weak-keyed memo means logging does not extend tensor lifetimes
        if t not in self.memo:
            self.memo[t] = self.next_id
            self.next_id += 1
        return self.memo[t]

    def _fmt(self, a: object, with_type: bool = False) -> str:
        cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
        if isinstance(a, cond_cls):
            maybe_type = ""
            if with_type and self.with_type:
                maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
            return f"${self._shortid(a)}{maybe_type}"
        else:
            return repr(a)

    def emit(self, record):
        # record.args is the (args, kwargs, outputs) triple passed to
        # logger.info in __torch_dispatch__ and log_input
        fmt_args = ", ".join(
            itertools.chain(
                (str(tree_map(self._fmt, a)) for a in record.args[0]),
                (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
            )
        )
        fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
        if self.tracebacks_list is not None:
            self.tracebacks_list.append(record.traceback)

def log_input(name: str, var: object) -> None:
    logger.info("input", (name,), {}, var)  # noqa: PLE1205

class GatherTraceback(logging.Filter):
    def __init__(self, python=True, script=True, cpp=False):
        super().__init__()
        self.python = python
        self.script = script
        self.cpp = cpp

    def filter(self, record):
        record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
        return True

@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,
        use_shortid_for_all_tensors=is_mode,
        tracebacks_list=tracebacks_list if collect_traceback else None
    )
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    traceback_filter = None
    if collect_traceback:
        traceback_filter = GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)
        logger.addFilter(traceback_filter)
    try:
        if collect_traceback:
            yield log_list, tracebacks_list
        else:
            yield log_list
    finally:
        symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
        tracebacks_list.clear()
        tracebacks_list.extend(symbolized_tracebacks)
        logger.removeHandler(handler)
        # Remove the filter as well, so repeated uses don't stack traceback collectors
        if traceback_filter is not None:
            logger.removeFilter(traceback_filter)

@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
    with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
        yield logs
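
# Example (illustrative sketch; the shape of the yielded value depends on the
# flags): the combined helper installs LoggingTensorMode and the capture
# handler together. With any traceback flag set, capture_logs yields a
# (logs, tracebacks) pair, and the tracebacks are symbolized on exit:
#
#   with capture_logs_with_logging_tensor_mode(python_tb=True) as (logs, tbs):
#       torch.ones(1).mul(2)
#   # After the block: len(tbs) == len(logs), and tbs[i] is the symbolized
#   # traceback captured at the point logs[i] was recorded.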