# mypy: allow-untyped-defs
# This file establishes the public comptime interface to Dynamo.
# This allows Dynamo users to execute arbitrary Python code while
# Dynamo is symbolically evaluating their original programs.
#
# The goal of the public API is to give users rope, without actually
# leaking private implementation details of Dynamo.

import builtins
import dis
import time
import traceback
from typing import Optional, Union

import torch
from torch.fx.experimental.symbolic_shapes import free_symbols

from .exc import unimplemented
from .variables import NewCellVariable
from .variables.constant import ConstantVariable
from .variables.misc import ClosureVariable
from .variables.tensor import SymNodeVariable


class ComptimeVar:
    """
    A ComptimeVar represents a Python value, at some particular point
    in time, in the Python code we are symbolically evaluating with
    torchdynamo. This must be distinguished from a runtime value, as
    at compile-time there are some properties of the variable we
    do not know (for example, if the ComptimeVar represents a Tensor,
    we only know metadata about the tensor; we do NOT know what the
    actual data in the Tensor is.)
    """

    def __init__(self, v) -> None:
        self.__variable = v

    def as_proxy(self):
        """
        Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
        this variable in the FX graph we are assembling to pass
        to the user compiler.

        This method only works for variables we actually track in
        the FX graph, aka Tensors (and ints, if you are compiling
        with dynamic shapes). In particular, if you have a list
        or tuple of tensors, you will get a list/tuple of proxies
        (not a single proxy representing the entire list/tuple).
        """
        return self.__variable.as_proxy()

    def is_proxy(self):
        """
        Returns True if as_proxy() would succeed.
        """
        return self.__variable.is_proxy()

    def as_fake(self):
        """
        Returns a "fake" value (either a FakeTensor or a SymInt)
        representing the variable in question. This only works
        for variables that denote a Tensor or int. You can use
        this to query metadata; e.g., v.as_fake().size(0) will
        tell you the compile-time known size of the tensor.

        WARNING: Do NOT mutate the returned tensor.
        """
        return self.__variable.as_proxy().node.meta["example_value"]

    def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]:
        """
        Returns the size of the tensor (if dim is None) or the size
        at the dimension dim. The returned size may be a SymInt.
        """
        return self.as_fake().size(dim)

    def python_type(self):
        """
        Returns what type(v) would have returned for the variable
        at compile time.
        """
        return self.__variable.python_type()

    def as_python_constant(self):
        """
        Returns the Python value this variable would have, but only if it is
        completely known at compile-time (e.g., it is constant).

        WARNING: Do NOT mutate the returned constant. The returned constant
        may or may not correspond to the actual value this variable may take
        on at runtime; for example, if the variable in question is a constant
        list, we may return a copy of that list.
        """
        return self.__variable.as_python_constant()

    def is_python_constant(self):
        """
        Returns True if as_python_constant() would succeed.
        """
        return self.__variable.is_python_constant()

    def is_dynamic(self):
        """
        Returns True if this variable denotes a symbolic value (e.g., a SymInt)
        whose value is not statically known at compile time.
        """
        if isinstance(self.__variable, SymNodeVariable):
            fs = free_symbols(self.__variable.sym_num)
            return bool(fs)
        return False

    def force_static(self):
        """
        Forces the value to be static, inducing a guard on its specific value.
        """
        if isinstance(self.__variable, SymNodeVariable):
            self.__variable.evaluate_expr()
        elif isinstance(self.__variable, ConstantVariable):
            # TODO: Maybe complain if this isn't an int/bool/float variable
            pass
        else:
            raise AssertionError(
                f"cannot force {self.__variable} ({type(self.__variable)}) static"
            )

    def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
        """
        Returns the internal data structure VariableTracker that Dynamo uses
        to represent variables at compile time. There are no BC guarantees on
        this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on
        it.
        """
        return self.__variable

    def __repr__(self) -> str:
        return self.__variable.debug_repr()

    # TODO: API for adding a custom guard
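

# Illustrative sketch (not part of the API, not executed here): inside a
# comptime callback, a ComptimeVar is how you inspect the compile-time view of
# a local. The local name "x" below is hypothetical and assumed to be a tensor
# argument of the traced function.
#
#     def callback(ctx):
#         x = ctx.get_local("x")              # ComptimeVar for the local "x"
#         if x.is_proxy():
#             ctx.print(x.as_fake())          # FakeTensor metadata, no real data
#             ctx.print(x.size(0))            # may be an int or a SymInt
#             ctx.print(x.is_dynamic())       # True if the value has free symbols
#         elif x.is_python_constant():
#             ctx.print(x.as_python_constant())
#
#     comptime(callback)
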
class ComptimeContext:
    """
    This context class provides access to a public API for Dynamo's internals.
    If there is something here you would find useful that is missing, please
    file a feature request at https://github.com/pytorch/pytorch/
    """

    def __init__(self, tx) -> None:
        self.__tx = tx

    def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
        """
        Retrieve the compile-time known information about a local.
        """
        tx = self.__get_tx(stacklevel)

        # This is analogous to LOAD_DEREF
        if hasattr(tx, "closure_cells") and name in tx.closure_cells:
            cell = tx.closure_cells[name]
            if isinstance(cell, ClosureVariable):
                return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name])
            else:
                return ComptimeVar(tx.output.side_effects.load_cell(cell))
        else:
            r = tx.symbolic_locals[name]
            if isinstance(r, NewCellVariable):
                return ComptimeVar(tx.output.side_effects.load_cell(r))
            else:
                return ComptimeVar(r)

    def graph_break(self, msg="ComptimeContext.graph_break"):
        """
        Manually trigger a graph break.
        """
        unimplemented(msg)

    def graph(self):
        """
        Retrieve the partially constructed FX graph that would be
        passed to the user compiler after compilation.
        """
        return self.__tx.output.graph

    def assert_static(self, val):
        """
        Asserts that the value is static (not dynamic, in the sense of
        dynamic shapes).
        """
        assert (
            not val.is_dynamic()
        ), "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)"

    def print_graph(self, *, verbose=True, file=None):
        """
        Print the partially constructed FX graph that would be passed
        to the user compiler after compilation.
        """
        print(
            self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
        )

    def parent(self):
        """
        Returns a ComptimeContext for the parent (calling) frame.
        """
        return ComptimeContext(self.__tx.parent)

    def __get_tx(self, stacklevel):
        tx = self.__tx
        for _ in range(stacklevel):
            tx = tx.parent
        return tx

    def print(self, val, *, file=None):
        print(repr(val), file=file)

    def print_disas(self, *, file=None, stacklevel=0):
        """
        Print the current series of opcodes being executed (not including
        parent frames), including where you are in the particular opcode
        stream.
        """
        tx = self.__get_tx(stacklevel)
        print(
            dis.Bytecode(
                tx.f_code,
                current_offset=tx.instructions[tx.instruction_pointer].offset,
            ).dis(),
            file=file,
        )

    def print_value_stack(self, *, file=None, stacklevel=0):
        """
        Print the current Python value stack. Note that this is NOT the same
        as the traceback; use print_bt() to print that. Note that at
        stacklevel=0, this will typically be empty, as comptime cannot
        currently be used in an expression context where there would be
        intermediates on the stack. If you would find this useful, please
        file a bug at https://github.com/pytorch/pytorch/

        NB: Stack grows downwards in our print
        """
        tx = self.__get_tx(stacklevel)
        for s in tx.stack:
            print(f"- {s.debug_repr()}", file=file)

    def print_locals(self, *, file=None, stacklevel=0):
        """
        Print all of the locals available in the current context.
        By default this view is very limited; you can get more information
        about any individual local using get_local().
        """
        tx = self.__get_tx(stacklevel)
        for k, v in tx.symbolic_locals.items():
            print(f"{k} = {v.debug_repr()}", file=file)

    def print_bt(self, *, file=None, stacklevel=0):
        """
        Print the user code backtrace, starting at the frame where Dynamo
        began evaluating. Note that this MAY NOT go all the way to the
        torch.compile invocation, as we may have done a graph break and are
        compiling an intermediate frame as the starting point. If you think
        the other behavior would be better, file a bug at
        https://github.com/pytorch/pytorch/
        """
        stack = []
        tx = self.__get_tx(stacklevel)
        while tx is not None:
            stack.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)
        print(
            "".join(traceback.StackSummary.from_list(reversed(stack)).format()),
            file=file,
        )

    def print_guards(self, *, file=None):
        """
        Print the currently installed guards for the Dynamo context.
        This does NOT include guards associated with variables that
        may or may not be installed in the future if those variables
        are used.
        """
        # TODO: improve print format, current guard format is extremely
        # verbose
        print(
            "\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)),
            file=file,
        )

    def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
        """
        Returns the internal data structure InstructionTranslator that Dynamo
        uses to track the state of symbolic evaluation. There are no BC
        guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if
        you rely on it.
        """
        return self.__tx

    def sleep(self, sec):
        time.sleep(sec)
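

# Illustrative sketch (hypothetical model code, not executed here): a callback
# passed to comptime() receives a ComptimeContext and can use the helpers
# above to inspect the in-progress compilation. The function f and local
# names are assumptions for the example.
#
#     @torch.compile
#     def f(x):
#         y = x * 2
#         comptime(lambda ctx: ctx.print_locals())   # locals Dynamo knows so far
#         comptime(lambda ctx: ctx.print_graph())    # partial FX graph so far
#         comptime(lambda ctx: ctx.print_bt())       # user-code backtrace
#         return y + 1
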
class _Comptime:
    @staticmethod
    def __call__(fn, fallback_fn=lambda: None):
        """fn gets called at compile time in TorchDynamo; otherwise (when not
        compiling) fallback_fn is called instead"""
        fallback_fn()

    # Convenience wrappers that are more compact to use

    @staticmethod
    def graph_break():
        comptime(lambda ctx: ctx.graph_break())

    @staticmethod
    def print(e):
        comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e))

    @staticmethod
    def print_graph():
        comptime(lambda ctx: ctx.print_graph())

    @staticmethod
    def print_disas(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_disas(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_value_stack(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_value_stack(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    # This is a more useful variant of print_value_stack that can be used
    # in an expression context; e.g., in x + print_value_stack_and_return(y + z),
    # you will see x on the stack prior to the addition operation
    @staticmethod
    def print_value_stack_and_return(e, *, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_value_stack(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )
        return e

    @staticmethod
    def print_locals(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_locals(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_bt(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_bt(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_guards():
        comptime(lambda ctx: ctx.print_guards())

    @staticmethod
    def assert_static(val):
        comptime(lambda ctx: ctx.assert_static(ctx.get_local("val")))

    @staticmethod
    def force_static(val):
        comptime(lambda ctx: ctx.get_local("val").force_static())

    @staticmethod
    def breakpoint():
        """
        Like pdb breakpoint(), but drops into pdb whenever this line
        of code is compiled by Dynamo. Use it by putting
        this in your model code::

            from torch._dynamo.comptime import comptime
            comptime.breakpoint()

        And then, inside pdb, you can access 'ctx' to query things
        about the compilation context::

            (Pdb) !ctx.print_bt()
            (Pdb) !ctx.print_locals()
            (Pdb) p ctx.get_local("attention").as_fake()
        """

        def inner(inner_ctx):
            # 'ctx' is not used in this function; it exists so you can inspect
            # it interactively from the pdb prompt.
            ctx = inner_ctx.parent()
            builtins.breakpoint()

        comptime(inner)

    @staticmethod
    def sleep(sec):
        comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant()))
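

# The singleton below is what users import. Called directly, it runs its
# callback only under Dynamo; the optional second argument is the eager-mode
# fallback. Illustrative sketch (hypothetical, not executed here):
#
#     comptime(
#         lambda ctx: ctx.print_bt(),   # fn: runs under Dynamo at compile time
#         lambda: None,                 # fallback_fn: runs when not compiling
#     )
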
comptime = _Comptime()
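

# Illustrative usage sketch (hypothetical model code, not executed here). The
# static methods on comptime are convenience wrappers around the
# ComptimeContext methods above; f, x, and y are assumptions for the example.
#
#     from torch._dynamo.comptime import comptime
#
#     @torch.compile
#     def f(x):
#         y = x * 2
#         comptime.print(y)        # compile-time repr under Dynamo; print(y) in eager
#         comptime.print_graph()   # dump the partial FX graph built so far
#         comptime.breakpoint()    # drop into pdb while Dynamo compiles this line
#         return y + 1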