xref: /aosp_15_r20/external/pytorch/torch/utils/viz/_cycles.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import gc
3import sys
4from typing import Any, Dict, List, NamedTuple, Optional, Tuple
5import types
6import weakref
7import json
8from tempfile import NamedTemporaryFile
9import torch
10from torch.cuda._memory_viz import _frames_fmt, _block_extra
11import atexit
12import logging
13logger = logging.getLogger(__name__)
14
def observe_garbage(observer):
    """Install a GC hook that reports every object freed as cyclic garbage.

    While armed, each collection runs with ``gc.DEBUG_SAVEALL`` so unreachable
    objects are parked in ``gc.garbage`` instead of being freed; once the
    collection finishes, ``observer(gc.garbage)`` is called and the saved
    garbage is then actually released by a follow-up collection.

    Returns a zero-argument function that disarms and removes the hook.
    """
    enabled = True

    def disable():
        # when GC runs during exit, things like `sys` will already be unloaded
        # so we have to disable the callback to avoid hitting errors.
        nonlocal enabled
        enabled = False
    atexit.register(disable)

    def gc_callback(phase, info):
        nonlocal enabled
        if not enabled:
            return
        if phase == "start":
            # Keep everything this collection finds unreachable alive in
            # gc.garbage so it can be reported after the collection completes.
            gc.set_debug(gc.DEBUG_SAVEALL)
        elif phase == "stop":
            # We can't do arbitrary work (run the observer, trigger another
            # collection) from inside the GC callback itself, so defer it via
            # a profile hook that fires on a subsequent call/return event.
            orig_trace = sys.getprofile()
            self_return = [False]

            def do_collect(*args, **kwargs):
                nonlocal enabled
                if not self_return[0]:
                    # Skip the first profile event (presumably triggered by
                    # installing the hook itself) and act on the next one.
                    self_return[0] = True
                else:
                    sys.setprofile(orig_trace)
                    # Disarm while we deliberately collect below, so
                    # gc_callback doesn't re-enter this machinery.
                    enabled = False
                    try:
                        # things in gc.garbage have survived a collection
                        # so to free them we have to collect a generation greater than them
                        # but that might _also_ free other stuff and we don't want to miss
                        # that stuff. So we have to now force gc at the highest level here,
                        # report all of what we found, _then_ we can free it up.
                        if info['generation'] != 2:
                            gc.collect()
                        observer(gc.garbage)
                        gc.garbage.clear()
                        # we have to re-run GC to clean up the cycles
                        # we saved from before.
                        gc.set_debug(0)
                        before = torch.cuda.memory_allocated()
                        gc.collect()
                        after = torch.cuda.memory_allocated()
                        if before != after:
                            logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after)
                    finally:
                        enabled = True
                # Chain to whatever profile function was installed before us.
                if orig_trace is not None:
                    return orig_trace(*args, **kwargs)
            sys.setprofile(do_collect)

    gc.callbacks.append(gc_callback)

    # provide a way to disarm the callback
    def remove():
        gc.callbacks.remove(gc_callback)
    return remove
72
73# Function to visualize cycles adapated from refcycle:
74# Copyright 2013 Mark Dickinson
75#
76# Licensed under the Apache License, Version 2.0 (the "License");
77# you may not use this file except in compliance with the License.
78# You may obtain a copy of the License at
79#
80#   http://www.apache.org/licenses/LICENSE-2.0
81#
82# Unless required by applicable law or agreed to in writing, software
83# distributed under the License is distributed on an "AS IS" BASIS,
84# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85# See the License for the specific language governing permissions and
86# limitations under the License.
87
def _get_cell_type():
    """Derive the closure-cell type by building a closure and inspecting it."""
    def make_closure(captured=None):
        return lambda: captured
    cell = make_closure().__closure__[0]
    return type(cell)

CellType = _get_cell_type()
94
def annotated_references(obj):
    """
    Return known information about references held by the given object.

    Returns a mapping from referents (keyed by ``id``) to lists of
    descriptions.  Note that there may be more than one edge leading to any
    particular referent; hence the need for a list.  Descriptions are
    currently strings.
    """
    references: Dict[int, List[str]] = {}

    def add_reference(name, obj):
        references.setdefault(id(obj), []).append(name)

    def add_attrs(*attrs):
        # Record each attribute that is present on `obj` as a named edge.
        for attr in attrs:
            if hasattr(obj, attr):
                add_reference(attr, getattr(obj, attr))

    def add_cell_references():
        try:
            add_attrs("cell_contents")
        except ValueError:
            # if cell_contents is empty,
            # accessing it raises ValueError
            # in this case there is no object to
            # annotate
            pass

    def add_function_references():
        # Bug fix: the comma after "__doc__" was missing, so implicit string
        # concatenation produced the bogus attribute name "__doc____qualname__"
        # and neither __doc__ nor __qualname__ was ever annotated.
        add_attrs("__defaults__",
                  "__closure__",
                  "__globals__",
                  "__code__",
                  "__name__",
                  "__module__",
                  "__doc__",
                  "__qualname__",
                  "__annotations__",
                  "__kwdefaults__")

    def add_sequence_references():
        for position, item in enumerate(obj):
            add_reference(f"[{position}]", item)

    def add_dict_references():
        for key, value in obj.items():
            add_reference("key", key)
            add_reference(f"[{repr(key)}]", value)

    def add_set_references():
        for elt in obj:
            add_reference("element", elt)

    def add_bound_method_references():
        add_attrs("__self__", "__func__", "im_class")

    def add_weakref_references():
        # For subclasses of weakref, we can't reliably distinguish the
        # callback (if any) from other attributes.
        if type(obj) is weakref.ref:
            referents = gc.get_referents(obj)
            if len(referents) == 1:
                target = referents[0]
                add_reference("__callback__", target)

    def add_frame_references():
        f_locals = obj.f_locals
        add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals")
        # Some badly-behaved code replaces the f_locals dict with
        # something that doesn't support the full dict interface.  So we
        # only continue with the annotation if f_locals is a Python dict.
        if type(f_locals) is dict:
            for name, local in obj.f_locals.items():
                add_reference(f"local {name}", local)

    def add_getset_descriptor_references():
        add_attrs("__objclass__", "__name__", "__doc__")

    type_based_references = {
        tuple: add_sequence_references,
        list: add_sequence_references,
        dict: add_dict_references,
        set: add_set_references,
        frozenset: add_set_references,
        types.FunctionType: add_function_references,
        types.FrameType: add_frame_references,
        CellType: add_cell_references,
        types.MethodType: add_bound_method_references,
        weakref.ref: add_weakref_references,
        types.GetSetDescriptorType: add_getset_descriptor_references,
    }

    # Run the annotator for every type in the MRO that we know about, then
    # add the references nearly every object carries.
    for type_ in type(obj).__mro__:
        if type_ in type_based_references:
            type_based_references[type_]()

    add_attrs("__dict__", "__class__")
    if isinstance(obj, type):
        add_attrs("__mro__")

    return references
199
200###############################################################################
201# Object annotations.
202
203
# Types whose repr is short and informative enough to be shown verbatim in
# node labels (see object_annotation / format_sequence).
BASE_TYPES = (int, float, complex, type(None), str, bytes)
# Maximum length of a frame's filename in a node label; longer paths are
# truncated from the left with a "..." prefix.
FRAME_FILENAME_LIMIT = 32
206
def object_annotation(obj):
    """
    Return a string to be used for Graphviz nodes.

    The string should be short but as informative as possible.
    """

    def format_sequence(seq):
        # Show at most the first 8 elements: base types by repr, anything
        # else by its type name; append "...N" for the omitted remainder.
        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), seq))
        if len(seq) > 8:
            body = f'{body}, ...{len(seq) - 8}'
        return body

    # For basic types, use the repr.
    if isinstance(obj, BASE_TYPES):
        return repr(obj)
    if type(obj).__name__ == 'function':
        return f"function\n{obj.__name__}"
    elif isinstance(obj, types.MethodType):
        try:
            func_name = obj.__func__.__qualname__
        except AttributeError:
            func_name = "<anonymous>"
        return f"instancemethod\n{func_name}"
    elif isinstance(obj, list):
        return f"[{format_sequence(obj)}]"
    elif isinstance(obj, tuple):
        return f"({format_sequence(obj)})"
    elif isinstance(obj, dict):
        return f"dict[{len(obj)}]"
    elif isinstance(obj, types.ModuleType):
        return f"module\n{obj.__name__}"
    elif isinstance(obj, type):
        return f"type\n{obj.__name__}"
    elif isinstance(obj, weakref.ref):
        referent = obj()
        if referent is None:
            return "weakref (dead referent)"
        else:
            return f"weakref to id 0x{id(referent):x}"
    elif isinstance(obj, types.FrameType):
        filename = obj.f_code.co_filename
        if len(filename) > FRAME_FILENAME_LIMIT:
            # Keep the tail of long paths; that is the informative part.
            filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
        # Bug fix: the truncated filename was computed but never used; the
        # label hard-coded "(unknown)" instead of showing where the frame is.
        return f"frame\n{filename}:{obj.f_lineno}"
    else:
        return f"object\n{type(obj).__module__}.{type(obj).__name__}"
254
255
256
class Node(NamedTuple):
    # Short human-readable description of the object (from object_annotation).
    label: str
    # Allocation stack trace for the object, or None when unavailable
    # (with the defaults in create_graph: only CUDA tensors get one).
    context: Optional[str]
    # True when the object passed create_graph's filter (default: is_cuda_tensor);
    # such nodes seed the reachability pruning.
    root: bool
    # Outgoing edges as (edge label, index of referent node).  NOTE: the name
    # is a misspelling of "referents" but is part of the public tuple API.
    referrents: List[Tuple[str, int]]
262
def create_graph(objects, *, context=None, filter=None):
    """Build a reference graph over *objects*, pruned to nodes that reach a root.

    Each object becomes a Node; edges follow gc.get_referents, labeled via
    annotated_references.  Nodes from which no filtered ("root") node is
    reachable are dropped, and edge targets are renumbered accordingly.
    """
    if context is None:
        context = cuda_allocation_context()
    if filter is None:
        filter = is_cuda_tensor

    nodes = [Node(object_annotation(o), context(o), filter(o), []) for o in objects]
    referrers_of: List[List[int]] = [[] for _ in objects]
    index_of = {id(o): i for i, o in enumerate(objects)}

    for obj in objects:
        src_idx = index_of[id(obj)]
        src = nodes[src_idx]
        edge_labels = annotated_references(obj)
        for referent in gc.get_referents(obj):
            dst_idx = index_of.get(id(referent), None)
            if dst_idx is None:
                # Referent is outside the object set; no edge.
                continue
            referrers_of[dst_idx].append(src_idx)
            for label in edge_labels.get(id(referent), ["?"]):
                src.referrents.append((label, dst_idx))

    # Walk referrer edges backwards from every root to find the nodes worth keeping.
    pending = [i for i, n in enumerate(nodes) if n.root]
    reachable = set()
    while pending:
        i = pending.pop()
        if i in reachable:
            continue
        reachable.add(i)
        pending.extend(referrers_of[i])

    # Compact the kept nodes and remap edge targets to the new indices.
    remap: Dict[int, int] = {}
    kept: List[Any] = []
    for i, n in enumerate(nodes):
        if i in reachable:
            remap[i] = len(remap)
            kept.append(n)
    for n in kept:
        n.referrents[:] = [(label, remap[i])
                           for (label, i) in n.referrents
                           if i in remap]
    return kept
308
def escape(n):
    """Quote *n* as a JSON string literal, safe to embed in dot/JS output."""
    quoted = json.dumps(n)
    return quoted
311
312
def is_cuda_tensor(obj):
    """Return True for real (non-fake) torch.Tensor objects on a CUDA device."""
    on_gpu = isinstance(obj, torch.Tensor) and obj.is_cuda
    return on_gpu and not isinstance(obj, torch._subclasses.FakeTensor)
315
def cuda_allocation_context():
    """Return a callable mapping a CUDA tensor to its allocation stack trace.

    Takes a snapshot of the CUDA caching allocator, indexes the frames of
    every active allocation by address, and returns a function that yields a
    newline-joined stack string for a CUDA tensor (or None if unknown).
    """
    snapshot = torch.cuda.memory._snapshot()
    frames_by_addr = {}
    for segment in snapshot['segments']:
        base = segment['address']
        for block in segment['blocks']:
            if block['state'] == 'active_allocated':
                frames, _real_size = _block_extra(block)
                frames_by_addr[base] = frames
            base += block['size']

    def object_context(obj):
        if not is_cuda_tensor(obj):
            return None
        frames = frames_by_addr.get(obj.untyped_storage().data_ptr())
        if frames is None:
            return None
        return '\n'.join(_frames_fmt(frames, full_filename=True))
    return object_context
335
def to_dot(nodes):
    """Render the node list as Graphviz dot source (roots drawn in red)."""
    out = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
    for idx, node in enumerate(nodes):
        color = "red" if node.root else "black"
        out.append(f'{idx} [label={escape(node.label)}, color={color}];')

    for src, node in enumerate(nodes):
        for label, dst in node.referrents:
            out.append(f'{src} -> {dst} [label = {escape(label)}]')
    out.append("}\n")
    return '\n'.join(out)
346
# Standalone HTML page: $DOT is replaced with the Python repr of the dot
# source (which doubles as a JS string literal) and $LISTENERS with the
# per-node mouseover handlers built from _listener_template below.
_template = """
<!DOCTYPE html>
<html>
<head>
  <style>
    body {
      margin: 0;
      padding: 0;
      overflow: hidden;
    }

    #container {
      display: flex;
      flex-direction: column;
      height: 100vh;
    }

    #main {
      flex: 2;
      overflow: auto;
    }

    #preContainer {
      flex: 1;
      overflow: auto;
    }

    svg {
        overflow: scroll;
    }

    pre {
      margin: 0;
      padding: 10px;
    }
  </style>
</head>
<body>
  <div id="container">
    <div id="main">
    </div>
    <div id="preContainer">
      <pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
    </div>
  </div>
<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
<script>
let dot = $DOT
let image = Viz(dot, {format: 'svg'});
document.getElementById('main').innerHTML = image
$LISTENERS
</script>
</body>
</html>
"""
# Mouseover handler for one graph node: {id} is the 1-based node index
# (presumably matching the element ids viz.js assigns to the rendered SVG —
# verify against the viz.js output) and {stack} a JSON-quoted stack string.
_listener_template = """
document.getElementById('node{id}').addEventListener('mouseover', function(event) {{
  document.getElementById("stacktrace").textContent = {stack}
}})
"""
def to_html(nodes):
    """Render the node graph as a standalone HTML page.

    Nodes with an allocation context get a mouseover handler that shows the
    context in the page's stack-trace pane.
    """
    hover_scripts = []
    for idx, node in enumerate(nodes):
        if node.context is None:
            continue
        stack_text = escape(f'{node.label}:\n{node.context}')
        hover_scripts.append(_listener_template.format(id=str(idx + 1), stack=stack_text))
    page = _template.replace('$DOT', repr(to_dot(nodes)))
    return page.replace('$LISTENERS', '\n'.join(hover_scripts))
416
def observe_tensor_cycles(callback):
    """Watch GC garbage for cycles containing CUDA tensors.

    Enables CUDA memory-history recording (so allocation stacks are available)
    and installs a garbage observer: whenever collected garbage contains a
    CUDA tensor, *callback* receives an HTML visualization of the cycle.
    Returns the remover from observe_garbage.
    """
    torch.cuda.memory._record_memory_history(max_entries=100000)

    def observer(garbage):
        if not garbage:
            return
        if any(is_cuda_tensor(candidate) for candidate in garbage):
            callback(to_html(create_graph(garbage)))
        else:
            logger.info("No CUDA Tensors found in garbage")
    return observe_garbage(observer)
427
428
def warn_tensor_cycles():
    """
    Install a warning that reports whenever a cycle that is holding CUDA memory is observed.

    The warning produces an .html file that visualizes the cycle,
    and links it to the stack frame that allocated the CUDA tensor.

    Reference cycles are freed by the cycle collector rather than being cleaned up
    when the objects in the cycle first become unreachable. If a cycle points to a tensor,
    the CUDA memory for that tensor will not be freed until garbage collection runs.
    Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as
    non-deterministic allocation behavior which is harder to debug.
    """
    logger.info("Watching Python reference cycles for CUDA Tensors.")

    def write_and_log(html):
        # Persist the visualization (delete=False keeps it after close) and
        # point the user at it via a warning.
        with NamedTemporaryFile('w', suffix='.html', delete=False) as out:
            out.write(html)
            logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', out.name)
    return observe_tensor_cycles(write_and_log)
449