1# mypy: allow-untyped-defs 2import gc 3import sys 4from typing import Any, Dict, List, NamedTuple, Optional, Tuple 5import types 6import weakref 7import json 8from tempfile import NamedTemporaryFile 9import torch 10from torch.cuda._memory_viz import _frames_fmt, _block_extra 11import atexit 12import logging 13logger = logging.getLogger(__name__) 14 15def observe_garbage(observer): 16 enabled = True 17 18 def disable(): 19 # when GC runs during exit, things like `sys` will already be unloaded 20 # so we have to disable the callback to avoid hitting errors. 21 nonlocal enabled 22 enabled = False 23 atexit.register(disable) 24 25 def gc_callback(phase, info): 26 nonlocal enabled 27 if not enabled: 28 return 29 if phase == "start": 30 gc.set_debug(gc.DEBUG_SAVEALL) 31 elif phase == "stop": 32 orig_trace = sys.getprofile() 33 self_return = [False] 34 35 def do_collect(*args, **kwargs): 36 nonlocal enabled 37 if not self_return[0]: 38 self_return[0] = True 39 else: 40 sys.setprofile(orig_trace) 41 enabled = False 42 try: 43 # things in gc.garbage have survived a collection 44 # so to free them we have to collect a generation greater than them 45 # but that might _also_ free other stuff and we don't want to miss 46 # that stuff. So we have to now force gc at the highest level here, 47 # report all of what we found, _then_ we can free it up. 48 if info['generation'] != 2: 49 gc.collect() 50 observer(gc.garbage) 51 gc.garbage.clear() 52 # we have to re-run GC to clean up the cycles 53 # we saved from before. 54 gc.set_debug(0) 55 before = torch.cuda.memory_allocated() 56 gc.collect() 57 after = torch.cuda.memory_allocated() 58 if before != after: 59 logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after) 60 finally: 61 enabled = True 62 if orig_trace is not None: 63 return orig_trace(*args, **kwargs) 64 sys.setprofile(do_collect) 65 66 gc.callbacks.append(gc_callback) 67 68 # provide a way to disarm the callback 69 def remove(): 70 gc.callbacks.remove(gc_callback) 71 return remove 72 73# Function to visualize cycles adapated from refcycle: 74# Copyright 2013 Mark Dickinson 75# 76# Licensed under the Apache License, Version 2.0 (the "License"); 77# you may not use this file except in compliance with the License. 78# You may obtain a copy of the License at 79# 80# http://www.apache.org/licenses/LICENSE-2.0 81# 82# Unless required by applicable law or agreed to in writing, software 83# distributed under the License is distributed on an "AS IS" BASIS, 84# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 85# See the License for the specific language governing permissions and 86# limitations under the License. 87 88def _get_cell_type(): 89 def f(x=None): 90 return lambda: x 91 return type(f().__closure__[0]) 92 93CellType = _get_cell_type() 94 95def annotated_references(obj): 96 """ 97 Return known information about references held by the given object. 98 99 Returns a mapping from referents to lists of descriptions. Note that there 100 may be more than one edge leading to any particular referent; hence the 101 need for a list. Descriptions are currently strings. 102 103 """ 104 references: Dict[int, List[str]] = {} 105 106 def add_reference(name, obj): 107 references.setdefault(id(obj), []).append(name) 108 109 def add_attrs(*attrs): 110 for attr in attrs: 111 if hasattr(obj, attr): 112 add_reference(attr, getattr(obj, attr)) 113 114 def add_cell_references(): 115 try: 116 add_attrs("cell_contents") 117 except ValueError: 118 # if cell_contents is empty, 119 # accessing it raises ValueError 120 # in this case there is no object to 121 # annotate 122 pass 123 124 def add_function_references(): 125 add_attrs("__defaults__", 126 "__closure__", 127 "__globals__", 128 "__code__", 129 "__name__", 130 "__module__", 131 "__doc__" 132 "__qualname__", 133 "__annotations__", 134 "__kwdefaults__") 135 136 137 def add_sequence_references(): 138 for position, item in enumerate(obj): 139 add_reference(f"[{position}]", item) 140 141 def add_dict_references(): 142 for key, value in obj.items(): 143 add_reference("key", key) 144 add_reference(f"[{repr(key)}]", value) 145 146 def add_set_references(): 147 for elt in obj: 148 add_reference("element", elt) 149 150 def add_bound_method_references(): 151 add_attrs("__self__", "__func__", "im_class") 152 153 def add_weakref_references(): 154 # For subclasses of weakref, we can't reliably distinguish the 155 # callback (if any) from other attributes. 156 if type(obj) is weakref.ref: 157 referents = gc.get_referents(obj) 158 if len(referents) == 1: 159 target = referents[0] 160 add_reference("__callback__", target) 161 162 163 def add_frame_references(): 164 f_locals = obj.f_locals 165 add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals") 166 # Some badly-behaved code replaces the f_locals dict with 167 # something that doesn't support the full dict interface. So we 168 # only continue with the annotation if f_locals is a Python dict. 169 if type(f_locals) is dict: 170 for name, local in obj.f_locals.items(): 171 add_reference(f"local {name}", local) 172 173 def add_getset_descriptor_references(): 174 add_attrs("__objclass__", "__name__", "__doc__") 175 176 type_based_references = { 177 tuple: add_sequence_references, 178 list: add_sequence_references, 179 dict: add_dict_references, 180 set: add_set_references, 181 frozenset: add_set_references, 182 types.FunctionType: add_function_references, 183 types.FrameType: add_frame_references, 184 CellType: add_cell_references, 185 types.MethodType: add_bound_method_references, 186 weakref.ref: add_weakref_references, 187 types.GetSetDescriptorType: add_getset_descriptor_references, 188 } 189 190 for type_ in type(obj).__mro__: 191 if type_ in type_based_references: 192 type_based_references[type_]() 193 194 add_attrs("__dict__", "__class__") 195 if isinstance(obj, type): 196 add_attrs("__mro__") 197 198 return references 199 200############################################################################### 201# Object annotations. 202 203 204BASE_TYPES = (int, float, complex, type(None), str, bytes) 205FRAME_FILENAME_LIMIT = 32 206 207def object_annotation(obj): 208 """ 209 Return a string to be used for Graphviz nodes. 210 211 The string should be short but as informative as possible. 212 """ 213 214 def format_sequence(obj): 215 body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj)) 216 if len(obj) > 8: 217 body = f'{body}, ...{len(obj) - 8}' 218 return body 219 220 # For basic types, use the repr. 221 if isinstance(obj, BASE_TYPES): 222 return repr(obj) 223 if type(obj).__name__ == 'function': 224 return f"function\n{obj.__name__}" 225 elif isinstance(obj, types.MethodType): 226 try: 227 func_name = obj.__func__.__qualname__ 228 except AttributeError: 229 func_name = "<anonymous>" 230 return f"instancemethod\n{func_name}" 231 elif isinstance(obj, list): 232 return f"[{format_sequence(obj)}]" 233 elif isinstance(obj, tuple): 234 return f"({format_sequence(obj)})" 235 elif isinstance(obj, dict): 236 return f"dict[{len(obj)}]" 237 elif isinstance(obj, types.ModuleType): 238 return f"module\n{obj.__name__}" 239 elif isinstance(obj, type): 240 return f"type\n{obj.__name__}" 241 elif isinstance(obj, weakref.ref): 242 referent = obj() 243 if referent is None: 244 return "weakref (dead referent)" 245 else: 246 return f"weakref to id 0x{id(referent):x}" 247 elif isinstance(obj, types.FrameType): 248 filename = obj.f_code.co_filename 249 if len(filename) > FRAME_FILENAME_LIMIT: 250 filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):] 251 return f"frame\n{filename}:{obj.f_lineno}" 252 else: 253 return f"object\n{type(obj).__module__}.{type(obj).__name__}" 254 255 256 257class Node(NamedTuple): 258 label: str 259 context: Optional[str] 260 root: bool 261 referrents: List[Tuple[str, int]] 262 263def create_graph(objects, *, context=None, filter=None): 264 if context is None: 265 context = cuda_allocation_context() 266 if filter is None: 267 filter = is_cuda_tensor 268 269 nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects] 270 node_referrers: List[List[int]] = [[] for obj in objects] 271 272 id_to_node = {id(obj): i for i, obj in enumerate(objects)} 273 for obj in objects: 274 fidx = id_to_node[id(obj)] 275 f = nodes[fidx] 276 references = annotated_references(obj) 277 for referrent in gc.get_referents(obj): 278 rid = id(referrent) 279 tidx = id_to_node.get(rid, None) 280 if tidx is None: 281 continue 282 t = nodes[tidx] 283 labels = references.get(rid, ["?"]) 284 node_referrers[tidx].append(fidx) 285 for label in labels: 286 f.referrents.append((label, tidx)) 287 288 to_search = [i for i, n in enumerate(nodes) if n.root] 289 to_keep = set() 290 while to_search: 291 idx = to_search.pop() 292 if idx in to_keep: 293 continue 294 to_keep.add(idx) 295 referrers = node_referrers[idx] 296 to_search.extend(referrers) 297 id_to_filtered_id: Dict[int, int] = {} 298 filtered: List[Any] = [] 299 for i, n in enumerate(nodes): 300 if i in to_keep: 301 id_to_filtered_id[i] = len(id_to_filtered_id) 302 filtered.append(n) 303 for n in filtered: 304 n.referrents[:] = [(label, id_to_filtered_id[idx]) 305 for (label, idx) in n.referrents 306 if idx in id_to_filtered_id] 307 return filtered 308 309def escape(n): 310 return json.dumps(n) 311 312 313def is_cuda_tensor(obj): 314 return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor) 315 316def cuda_allocation_context(): 317 snapshot = torch.cuda.memory._snapshot() 318 addr_to_frame = {} 319 for seg in snapshot['segments']: 320 addr = seg['address'] 321 for blk in seg['blocks']: 322 if blk['state'] == 'active_allocated': 323 frames, real_size = _block_extra(blk) 324 addr_to_frame[addr] = frames 325 addr += blk['size'] 326 327 def object_context(obj): 328 if is_cuda_tensor(obj): 329 addr = obj.untyped_storage().data_ptr() 330 frames = addr_to_frame.get(addr) 331 if frames is not None: 332 return '\n'.join(_frames_fmt(frames, full_filename=True)) 333 return None 334 return object_context 335 336def to_dot(nodes): 337 lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;'] 338 for i, n in enumerate(nodes): 339 lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];') 340 341 for i, f in enumerate(nodes): 342 for label, j in f.referrents: 343 lines.append(f'{i} -> {j} [label = {escape(label)}]') 344 lines.append("}\n") 345 return '\n'.join(lines) 346 347_template = """ 348<!DOCTYPE html> 349<html> 350<head> 351 <style> 352 body { 353 margin: 0; 354 padding: 0; 355 overflow: hidden; 356 } 357 358 #container { 359 display: flex; 360 flex-direction: column; 361 height: 100vh; 362 } 363 364 #main { 365 flex: 2; 366 overflow: auto; 367 } 368 369 #preContainer { 370 flex: 1; 371 overflow: auto; 372 } 373 374 svg { 375 overflow: scroll; 376 } 377 378 pre { 379 margin: 0; 380 padding: 10px; 381 } 382 </style> 383</head> 384<body> 385 <div id="container"> 386 <div id="main"> 387 </div> 388 <div id="preContainer"> 389 <pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre> 390 </div> 391 </div> 392<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script> 393<script> 394let dot = $DOT 395let image = Viz(dot, {format: 'svg'}); 396document.getElementById('main').innerHTML = image 397$LISTENERS 398</script> 399</body> 400</html> 401""" 402_listener_template = """ 403document.getElementById('node{id}').addEventListener('mouseover', function(event) {{ 404 document.getElementById("stacktrace").textContent = {stack} 405}}) 406""" 407def to_html(nodes): 408 listeners = [] 409 for i, n in enumerate(nodes): 410 if n.context is None: 411 continue 412 s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) 413 listeners.append(s) 414 dot = to_dot(nodes) 415 return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners)) 416 417def observe_tensor_cycles(callback): 418 torch.cuda.memory._record_memory_history(max_entries=100000) 419 420 def observer(garbage): 421 if garbage: 422 if not any(is_cuda_tensor(obj) for obj in garbage): 423 logger.info("No CUDA Tensors found in garbage") 424 return 425 callback(to_html(create_graph(garbage))) 426 return observe_garbage(observer) 427 428 429def warn_tensor_cycles(): 430 """ 431 Install a warning that reports whenever a cycle that is holding CUDA memory is observed. 432 433 The warning produces an .html file that visualizes the cycle, 434 and links it to the stack frame that allocted the CUDA tensor. 435 436 Reference cycles are freed by the cycle collector rather than being cleaned up 437 when the objects in the cycle first become unreachable. If a cycle points to a tensor, 438 the CUDA memory for that tensor will not be freed until garbage collection runs. 439 Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as 440 non-deterministic allocation behavior which is harder to debug. 441 """ 442 logger.info("Watching Python reference cycles for CUDA Tensors.") 443 444 def write_and_log(html): 445 with NamedTemporaryFile('w', suffix='.html', delete=False) as f: 446 f.write(html) 447 logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) 448 return observe_tensor_cycles(write_and_log) 449