xref: /aosp_15_r20/external/pytorch/tools/gdb/pytorch-gdb.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import textwrap
2from typing import Any
3
4import gdb  # type: ignore[import]
5
6
class DisableBreakpoints:
    """
    Context-manager to temporarily disable all gdb breakpoints, useful if
    there is a risk to hit one during the evaluation of one of our custom
    commands.

    On entry, every currently-enabled breakpoint is disabled and remembered.
    On exit, exactly those breakpoints are re-enabled, whether or not the
    ``with`` body raised. Breakpoints that were already disabled are left
    untouched. Exceptions from the ``with`` body are never suppressed.
    """

    def __enter__(self) -> "DisableBreakpoints":
        self.disabled_breakpoints = []
        # gdb 7.4 and earlier return None (not an empty sequence) when there
        # are no breakpoints, hence the ``or ()``.
        for bp in gdb.breakpoints() or ():
            if bp.enabled:
                bp.enabled = False
                self.disabled_breakpoints.append(bp)
        return self

    def __exit__(self, etype: Any, evalue: Any, tb: Any) -> None:
        for bp in self.disabled_breakpoints:
            # A breakpoint may have been deleted while it was disabled.
            # Touching ``enabled`` on a deleted breakpoint raises
            # RuntimeError, which would mask any exception already
            # propagating out of the ``with`` body, so skip those.
            if bp.is_valid():
                bp.enabled = True
        # Forget what we restored so the manager can be safely reused.
        self.disabled_breakpoints = []
24
25
class TensorRepr(gdb.Command):  # type: ignore[misc, no-any-unimported]
    """
    Print a human readable representation of the given at::Tensor.
    Usage: torch-tensor-repr EXP

    at::Tensor instances do not have a C++ implementation of a repr method: in
    pytorch, this is done by pure-Python code. As such, torch-tensor-repr
    internally creates a Python wrapper for the given tensor and calls repr()
    on it.
    """

    # gdb uses the command class docstring as its help text; normalize the
    # indentation so `help torch-tensor-repr` prints cleanly.
    __doc__ = textwrap.dedent(__doc__).strip()

    def __init__(self) -> None:
        gdb.Command.__init__(
            self, "torch-tensor-repr", gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION
        )

    def invoke(self, args: str, from_tty: bool) -> None:
        """
        Evaluate ``torch::gdb::tensor_repr(EXP)`` in the inferior and print it.

        :param args: the raw argument string typed after the command; it must
            split (gdb argv rules) into exactly one word, the expression.
        :param from_tty: supplied by gdb; unused here.
        """
        argv = gdb.string_to_argv(args)
        if len(argv) != 1:
            print("Usage: torch-tensor-repr EXP")
            return
        name = argv[0]
        # Calling into the inferior could hit a user breakpoint; keep them
        # disabled for the duration of the call and the cleanup.
        with DisableBreakpoints():
            res = gdb.parse_and_eval(f"torch::gdb::tensor_repr({name})")
            try:
                text = res.string()
                print(f"Python-level repr of {name}:")
                print(text)
            finally:
                # torch::gdb::tensor_repr returns a malloc()ed buffer. Free
                # it even if reading or printing it failed, so repeated use of
                # the command does not leak memory in the inferior.
                gdb.parse_and_eval(f"(void)free({int(res)})")


# Instantiating the gdb.Command subclass registers "torch-tensor-repr" with gdb.
TensorRepr()
59