# mypy: allow-untyped-defs
"""
Contains utils for logging in AOTAutograd, including managing the names of the graphs under
compilation, capturing user-friendly tracebacks, and debug messages.
"""

import collections
from contextlib import contextmanager
from typing import List, Tuple

import torch
import torch.fx.traceback as fx_traceback


# This is a list since, looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes up,
# but this is annoying to do right now (because we don't know if an aot_id
# will come back from the dead), so for now this also happens to be a
# globally unique number (at the cost of wobbling if you change how the
# graphs compile).
nth_graph: int = 0
model_name: str = "model"


def set_model_name(name):
    global model_name
    model_name = name


def get_aot_compilation_context() -> Tuple[List[str], str, int]:
    return list(graph_being_compiled), model_name, nth_graph


def get_aot_graph_name() -> str:
    """
    Returns the name of the graph being compiled.
    """
    global model_name, graph_being_compiled, nth_graph
    return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"


get_graph_being_compiled = get_aot_graph_name


@contextmanager
def track_graph_compiling(aot_config, graph_name):
    """
    Context manager that records which graph (e.g. "forward" or "backward")
    is currently being compiled for the given aot_id, and bumps the global
    graph counter on exit.
    """
    global graph_being_compiled
    # TODO: Don't shove the aot_id in here; set it in the context
    graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
    old_name = None
    if tracing_context := torch._guards.TracingContext.try_get():
        old_name = tracing_context.aot_graph_name
        tracing_context.aot_graph_name = graph_being_compiled
        has_tracing_context = True
    else:
        has_tracing_context = False
    try:
        yield
    finally:
        global nth_graph
        nth_graph += 1
        graph_being_compiled = []
        if has_tracing_context:
            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.aot_graph_name = old_name


# Set up hooks so that during backward the FX stack_trace is properly set
callback_set = False


def setup_stacktrace_preservation_hooks(roots: List):
    """
    Registers autograd pre- and post-hooks on every node reachable from
    roots so that the forward stack trace is restored while the
    corresponding backward node executes.
    """

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q = collections.deque()  # type: ignore[var-annotated]
        for node in roots:
            if node is not None and node not in seen:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def get_callback(saved_stack_):
        def callback():
            global callback_set
            fx_traceback.set_stack_trace(saved_stack_)
            callback_set = False

        return callback

    def get_prehook(stack_, seq_nr):
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                torch.autograd.variable.Variable._execution_engine.queue_callback(  # type: ignore[attr-defined]
                    get_callback(fx_traceback.format_stack())
                )
                callback_set = True

            fx_traceback.set_stack_trace(stack_)
            fx_traceback.set_grad_fn_seq_nr(seq_nr)

        return prehook

    def get_posthook(special_stack_, seq_nr):
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(special_stack_)
            fx_traceback.reset_grad_fn_seq_nr()

        return posthook

    for node in iter_graph(roots):
        forward_node_stack = node.metadata.get("traceback_", [])
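        # The prehook restores this forward node's stack trace and sequence
        # number while its grad_fn runs, so FX nodes traced during backward
        # carry the originating forward frame in their stack_trace metadata.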
        node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr()))

        special_stack = forward_node_stack.copy()
        special_stack.append(
            "Gradient addition node due to multiple use of tensor around:"
        )
        node.register_hook(get_posthook(special_stack, node._sequence_nr()))


def describe_input(i, aot_config):
    if i < aot_config.num_params_buffers:
        return f"parameter/buffer {i}"
    else:
        return f"input {i - aot_config.num_params_buffers}"


def format_guard_bug_msg(aot_config, expected):
    return (
        f"At compilation time, graph {aot_config.aot_id} was compiled under the "
        f"assumption that {expected}, but at runtime this was not the case. "
        "This indicates a guard bug in AOTAutograd or Dynamo; please file a bug with PyTorch."
    )
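

# A minimal usage sketch (illustrative only): `aot_config` below is a stand-in
# namespace exposing just the attributes these helpers read (`aot_id`,
# `num_params_buffers`), not the real AOTConfig.
#
#     from types import SimpleNamespace
#
#     aot_config = SimpleNamespace(aot_id=0, num_params_buffers=2)
#     with track_graph_compiling(aot_config, "forward"):
#         print(get_aot_graph_name())  # e.g. "model__0_forward_0"
#     print(describe_input(1, aot_config))  # "parameter/buffer 1"
#     print(describe_input(3, aot_config))  # "input 1"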