# mypy: allow-untyped-defs
"""
Contains utils for logging in AOTAutograd, including managing the names of the graphs under
compilation, capturing user-friendly tracebacks, and debug messages.
"""

import collections
from contextlib import contextmanager
from typing import List, Tuple

import torch
import torch.fx.traceback as fx_traceback


# This is a list since, looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes
# up, but this is annoying to do right now (because we don't know if
# an aot_id will come back from the dead), so right now this also happens
# to be a globally unique number (at the cost of wobbling if you change
# how the graphs compile)
nth_graph: int = 0
model_name: str = "model"


def set_model_name(name):
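    """Set the global model name used when constructing AOT graph names."""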
    global model_name
    model_name = name


def get_aot_compilation_context() -> Tuple[List[str], str, int]:
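    """Return a snapshot of (graphs being compiled, model name, graph counter)."""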
    return list(graph_being_compiled), model_name, nth_graph


def get_aot_graph_name() -> str:
    """
    Returns the name of the graph being compiled.
    """
    global model_name, graph_being_compiled, nth_graph
    return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"


get_graph_being_compiled = get_aot_graph_name
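# Illustrative example (hypothetical values): with model_name == "resnet",
# graph_being_compiled == ["0_forward"], and nth_graph == 3,
# get_aot_graph_name() returns "resnet__0_forward_3".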


@contextmanager
def track_graph_compiling(aot_config, graph_name):
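    """
    Context manager that records which graph ("{aot_id}_{graph_name}") is
    currently being compiled, mirrors it into the active TracingContext if one
    exists, and bumps the global graph counter on exit.
    """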
    global graph_being_compiled
    # TODO: Don't shove the aot_id in here; set it in the context
    graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
    old_name = None
    if tracing_context := torch._guards.TracingContext.try_get():
        old_name = tracing_context.aot_graph_name
        tracing_context.aot_graph_name = graph_being_compiled
        has_tracing_context = True
    else:
        has_tracing_context = False
    try:
        yield
    finally:
        global nth_graph
        nth_graph += 1
        graph_being_compiled = []
        if has_tracing_context:
            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.aot_graph_name = old_name


# Set up hooks so that, during backward, the fx stack_trace is properly set
callback_set = False


def setup_stacktrace_preservation_hooks(roots: List):
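    """
    Register pre/post hooks on every autograd node reachable from the given
    roots so that, while backward runs, fx_traceback reports the stack trace
    (and seq_nr) of the forward code that created the node being differentiated.
    """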
    def iter_graph(roots):
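        """Yield every autograd node reachable from roots, in BFS order over next_functions."""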
        if not roots:
            return
        seen = set()
        q = collections.deque()  # type: ignore[var-annotated]
        for node in roots:
            if node is not None and node not in seen:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def get_callback(saved_stack_):
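        """Return a callback that restores saved_stack_ and clears callback_set."""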
        def callback():
            global callback_set
            fx_traceback.set_stack_trace(saved_stack_)
            callback_set = False

        return callback

    def get_prehook(stack_, seq_nr):
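        """
        Return a prehook that makes fx_traceback report this node's forward stack
        trace and seq_nr during its backward execution; the first prehook to run
        in a backward pass also queues an engine callback that restores the
        previous stack trace once backward completes.
        """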
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                torch.autograd.variable.Variable._execution_engine.queue_callback(  # type: ignore[attr-defined]
                    get_callback(fx_traceback.format_stack())
                )
                callback_set = True

            fx_traceback.set_stack_trace(stack_)
            fx_traceback.set_grad_fn_seq_nr(seq_nr)

        return prehook

    def get_posthook(special_stack_, seq_nr):
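        """
        Return a posthook that sets the stack trace to special_stack_ and resets
        the grad_fn seq_nr once the node has executed.
        """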
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(special_stack_)
            fx_traceback.reset_grad_fn_seq_nr()

        return posthook

    for node in iter_graph(roots):
        forward_node_stack = node.metadata.get("traceback_", [])
        node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr()))

        special_stack = forward_node_stack.copy()
        special_stack.append(
            "Gradient addition node due to multiple use of tensor around:"
        )
        node.register_hook(get_posthook(special_stack, node._sequence_nr()))


def describe_input(i, aot_config):
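    """Return a human-readable description of flat input i (parameter/buffer vs. user input)."""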
    if i < aot_config.num_params_buffers:
        return f"parameter/buffer {i}"
    else:
        return f"input {i - aot_config.num_params_buffers}"


def format_guard_bug_msg(aot_config, expected):
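    """Format the error message shown when a compile-time assumption fails at runtime."""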
    return (
        f"At compilation time, graph {aot_config.aot_id} was compiled under the "
        f"assumption that {expected}, but at runtime this was not the case.  "
        "This indicates a guard bug in AOTAutograd or Dynamo; please file a bug with PyTorch."
    )