xref: /aosp_15_r20/external/pytorch/torch/utils/_traceback.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from types import TracebackType
3from typing import List, Optional
4import tempfile
5import traceback
6import contextlib
7import inspect
8import os.path
9
10# This file contains utilities for ensuring dynamically compile()'d
11# code fragments display their line numbers in backtraces.
12#
13# The constraints:
14#
15# - We don't have control over the user exception printer (in particular,
16#   we cannot assume the linecache trick will work, c.f.
17#   https://stackoverflow.com/q/50515651/23845 )
18#
19# - We don't want to create temporary files every time we compile()
20#   some code; file creation should happen lazily only at exception
21#   time.  Arguably, you *should* be willing to write out your
22#   generated Python code to file system, but in some situations
23#   (esp. library code) it would violate user expectation to write
24#   to the file system, so we try to avoid it.  In particular, we'd
25#   like to keep the files around, so users can open up the files
26#   mentioned in the trace; if the file is invisible, we want to
27#   avoid clogging up the filesystem.
28#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement the functionality in this file: instead of using
#   eval/exec directly, just always write a Python file to the filesystem
#   and compile that.
33#
34# - You have control over a context where the compiled code will get
35#   executed, so that we can interpose while the stack is unwinding
36#   (otherwise, we have no way to interpose on the exception printing
37#   process.)
38#
39# There are two things you have to do to make use of the utilities here:
40#
41# - When you compile your source code, you must save its string source
42#   in its f_globals under the magic name "__compile_source__"
43#
44# - Before running the compiled code, enter the
45#   report_compile_source_on_error() context manager.
46
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Context manager that makes dynamically compiled code show up in tracebacks.

    If an ``Exception`` escapes the managed block, its traceback is walked and
    every frame whose code object has the filename ``"<string>"`` and whose
    globals contain ``"__compile_source__"`` is replaced by a stand-in frame
    that points at a temporary ``.py`` file holding that source.  The exception
    is then re-raised with the rewritten traceback.

    The temporary files are intentionally left on disk so the user can open
    the files named in the trace.  One file is written per distinct source
    string seen while processing a given exception.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Path of the temporary file already written for each distinct
        # __compile_source__ string, so frames sharing a source share a file.
        source_files = {}

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        while tb is not None:
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            source = frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it.
                source_path = source_files.get(source)
                if source_path is None:
                    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                        f.write(source)
                    source_path = f.name
                    source_files[source] = source_path

                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                code = compile('__inspect_currentframe()', source_path, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Only code objects on newer Python versions carry co_linetable
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Evaluating the stub returns the frame it runs in, which
                # carries the replacement code object (and so the new filename).
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list, innermost entry first
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904
133
def shorten_filename(fn, *, base=None):
    """
    Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    The longest common path prefix of ``fn`` and ``base`` is removed from
    ``fn``, together with the separator that follows it.  ``base`` defaults to
    the parent of the directory containing this module.  If no common path can
    be computed (``os.path.commonpath`` raises ``ValueError``, e.g. when mixing
    absolute and relative paths), ``fn`` is returned unchanged.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    # The prefix is not always followed by exactly one separator: a root
    # prefix already ends with one and an empty prefix has none.  Strip the
    # prefix first and then any leading separators, so no real character of
    # the remaining path is dropped.
    separators = os.sep + (os.altsep or "")
    return fn[len(prefix):].lstrip(separators)
145
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly as ``<short path>:<lineno> in <name>``.

    The path is shortened with ``shorten_filename`` (relative to ``base``) and
    no source text is shown unless ``line`` is true, in which case the frame's
    source line is placed in front, followed by ``"  # "``.  The result is
    meant to fit on a single line.
    """
    source_prefix = f"{frame.line}  # " if line else ""
    location = shorten_filename(frame.filename, base=base)
    return f"{source_prefix}{location}:{frame.lineno} in {frame.name}"
156
def format_traceback_short(tb):
    """Return a one-line description of the inner-most frame of a TracebackType."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
160
class CapturedTraceback:
    """
    Handle on a traceback gathered by torch._C._profiler.gather_traceback.

    The raw traceback is kept unsymbolized until summary(), format() or
    format_all() is called.
    """

    # tb:   the raw captured traceback, or None once cleanup() has been called
    #       (or after unpickling, since the traceback is not pickled).
    # skip: number of leading entries of the symbolized traceback to drop;
    #       extract() uses it to elide its own frame.
    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the captured traceback; later summaries will be empty."""
        self.tb = None

    def summary(self):
        """Symbolize the captured traceback and return it as a traceback.StackSummary."""
        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        # Imported only when there is actually something to symbolize
        import torch._C._profiler

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # (dict state, slots state): this class only has slots
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback object that must be
        formatted specially with CapturedTraceback.format (or, for several
        tracebacks at once, CapturedTraceback.format_all).

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.  A non-zero skip is not supported together with script/cpp.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats a single CapturedTraceback into a list of strings equivalent
        to the output of traceback.format_list.  Note that if you have several
        CapturedTracebacks with C++ traces, it is better not to use this
        function and use the batch formatting API format_all to amortize
        the cost of symbolization
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        Entries whose traceback has been cleaned up format to an empty list.
        All remaining tracebacks are symbolized in a single batch.
        """
        # Directly populate tracebacks that have nothing left to symbolize
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        if not delayed_idxs:
            return rs

        import torch._C._profiler

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch result directly; calling summary() here would
            # symbolize each traceback a second time, one by one.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
245
246
def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (as produced by symbolize_tracebacks) into
    a traceback.StackSummary.

    The first ``skip`` entries are dropped and the remaining ones are emitted
    in reverse order, each as a FrameSummary built from its 'filename',
    'line' and 'name' fields.
    """
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(tb[skip:])
    )
    return summary
256