xref: /aosp_15_r20/external/pytorch/torch/fx/traceback.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import traceback
3from contextlib import contextmanager
4from typing import List, Any, Dict
5from ._compatibility import compatibility
6
# Public API of this module (everything else is an implementation detail).
__all__ = [
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
    "set_grad_fn_seq_nr",
    "reset_grad_fn_seq_nr",
    "format_stack",
    "set_current_meta",
    "get_current_meta",
]
10
# Metadata dict exposed through get_current_meta(); written by the set_*
# helpers and rebound by the context managers in this module.
current_meta: Dict[str, Any] = dict()
# Set to True only inside preserve_node_meta(); while False the set_* helpers
# leave current_meta untouched.
should_preserve_node_meta = False
13
14
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta():
    """Turn on node-metadata preservation for the body of a ``with`` block.

    While active, ``should_preserve_node_meta`` is True, so the ``set_*``
    helpers in this module record into ``current_meta``. On exit, whether
    normal or via an exception, both the flag and ``current_meta`` are put back
    to what they were on entry. Nested uses therefore unwind correctly.
    """
    global should_preserve_node_meta, current_meta

    prior_flag = should_preserve_node_meta
    # A shallow snapshot suffices: the helpers replace the values stored in
    # current_meta (e.g. by building a new list) rather than mutating them.
    prior_meta = current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = prior_flag
        current_meta = prior_meta
30
31
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    """Store ``stack``, concatenated into one string, as the current stack trace.

    The value goes to ``current_meta["stack_trace"]``. Nothing is recorded
    when node-meta preservation is off or when ``stack`` is empty.
    """
    # current_meta is only mutated here, never rebound, so no ``global`` needed.
    if not should_preserve_node_meta or not stack:
        return
    current_meta["stack_trace"] = "".join(stack)
38
39
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """Push ``seq_nr`` onto ``current_meta["grad_fn_seq_nr"]``.

    Also increments the ``current_meta["in_grad_fn"]`` nesting depth.
    ``reset_grad_fn_seq_nr`` undoes one call. No-op unless node-meta
    preservation is active.
    """
    if not should_preserve_node_meta:
        return
    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # Concatenate into a fresh list instead of appending in place. Snapshots
    # of current_meta, and node.meta copies, share the old list object.
    previous_seq_nrs = current_meta.get("grad_fn_seq_nr", [])
    current_meta["grad_fn_seq_nr"] = previous_seq_nrs + [seq_nr]
    depth = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = depth + 1
48
49
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """Undo the most recent ``set_grad_fn_seq_nr`` call.

    Decrements ``current_meta["in_grad_fn"]`` and drops the last entry of
    ``current_meta["grad_fn_seq_nr"]``. When the depth is 1, both keys are
    removed entirely. No-op unless node-meta preservation is active.

    Raises:
        AssertionError: if there is no outstanding ``set_grad_fn_seq_nr``
            call, i.e. ``in_grad_fn`` is absent or not positive.
    """
    # NB: reset state properly, this would be helpful towards supporting
    #     reentrant autograd if we actually wanted to do that.
    if not should_preserve_node_meta:
        return
    current_level = current_meta.get("in_grad_fn", 0)
    # Explicit check rather than ``assert``: ``python -O`` strips asserts.
    # An unbalanced reset would then fall into the else-branch below, store a
    # negative ``in_grad_fn`` and leave current_meta corrupted. The exception
    # type stays AssertionError for existing callers.
    if current_level <= 0:
        raise AssertionError(
            "reset_grad_fn_seq_nr() called without a matching set_grad_fn_seq_nr()"
        )
    if current_level == 1:
        del current_meta["in_grad_fn"]
        del current_meta["grad_fn_seq_nr"]
    else:
        current_meta["in_grad_fn"] = current_level - 1
        # Slice to a new list instead of popping, so shared lists stay intact.
        current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
64
65
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """Return a stack trace as a list of strings.

    With node-meta preservation active, this is a one-element list holding
    ``current_meta["stack_trace"]``, or ``""`` if none was recorded.
    Otherwise it is the live Python call stack in ``traceback.format_list``
    form, excluding this function's own frame.
    """
    if not should_preserve_node_meta:
        # Fallback to traceback.format_stack()-style output. The final entry
        # of extract_stack() is this very frame, so it is sliced off.
        frames = traceback.extract_stack()[:-1]
        return traceback.format_list(frames)
    recorded_trace = current_meta.get("stack_trace", "")
    return [recorded_trace]
73
74
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return the module-level ``should_preserve_node_meta`` flag.

    ``preserve_node_meta()`` sets this flag to True for the duration of its
    ``with`` block.
    """
    preserving = should_preserve_node_meta
    return preserving
78
79
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node):
    """Temporarily make a shallow copy of ``node.meta`` the current metadata.

    This only takes effect when node-meta preservation is active and
    ``node.meta`` is non-empty; otherwise the context does nothing. The copy
    records ``(node.name, node.target)`` in its ``"from_node"`` list for
    provenance tracking, unless the last entry already carries this node's
    name. The previous ``current_meta`` is restored on exit, even on error.
    """
    global current_meta
    if not (should_preserve_node_meta and node.meta):
        yield
        return

    outer_meta = current_meta
    try:
        current_meta = node.meta.copy()

        # Provenance tracking. The list is rebuilt rather than appended to,
        # because the shallow copy still shares it with node.meta.
        if "from_node" in current_meta:
            lineage = current_meta["from_node"]
            if lineage[-1][0] != node.name:
                current_meta["from_node"] = lineage + [(node.name, node.target)]
        else:
            current_meta["from_node"] = [(node.name, node.target)]

        yield
    finally:
        current_meta = outer_meta
100
101
@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
    """Return the module-level ``current_meta`` dict itself, not a copy.

    Mutating the returned dict mutates the current metadata directly.
    """
    live_meta = current_meta
    return live_meta
105