# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import enum
import functools
import logging
import threading
import traceback
import unittest.mock
import weakref
from abc import abstractmethod
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
)

from torch._C._dynamo.eval_frame import set_context_frame  # noqa: F401
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary


log = logging.getLogger(__name__)


if TYPE_CHECKING:
    import sympy

    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import torch


"""
torch._guards is the definitional source of truth for general purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""


class CompileId(NamedTuple):
    frame_id: int
    # This id is per-frame, and counts how many times we've compiled this
    # frame. This could have been a global id but having this be per-frame
    # gives you a better intuitive sense for how many recompiles have occurred
    # so far.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return f"{self.frame_id}/{self.frame_compile_id}"


class TraceId(NamedTuple):
    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self):
        if self.attempt == 0:
            return str(self.compile_id)
        else:
            return f"{self.compile_id}_{self.attempt}"


class GuardSource(enum.Enum):
    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15

    def is_fsdp_module(self) -> bool:
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        return (
            self
            in (
                GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            )
            # TODO (anijain2305) - Investigate why is_fsdp_module required.
            or self.is_fsdp_module()
        )

    def is_unspecialized_nn_module(self) -> bool:
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_unspecialized_builtin_nn_module(self) -> bool:
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_local(self):
        return self in (
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )
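
# Illustrative sketch (assumed values, not executed on import): the predicates above
# let guard code branch on where a value originated, e.g.:
#
#     GuardSource.LOCAL_FSDP_MODULE.is_fsdp_module()            # True
#     GuardSource.LOCAL_FSDP_MODULE.is_specialized_nn_module()  # True (FSDP counts)
#     GuardSource.GLOBAL.is_local()                             # False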


"""
Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder; it is kept to avoid a lot of renames and to preserve the original
reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""


class GuardBuilderBase:
    pass


class ShapeGuard(NamedTuple):
    expr: sympy.Expr
    stack: CapturedTraceback


@dataclasses.dataclass
class Guard:
    # originating_source is the source that called the make_guard method to
    # construct this guard object. The property name specifies what exactly it
    # is that the guard is guarding on. The meaning of the name is dependent on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # the name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on. This evaluation
    # typically happens in GuardBuilder.eval. In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the property source).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless. Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at the time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    _hash: Optional[int] = None

    def __hash__(self):
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        # Put the duplicate input guards at the end. The duplicate guards have
        # two sources while guard.name only considers one source.
        from torch._dynamo.guards import GuardBuilder

        is_duplicate_input = (
            isinstance(self.create_fn, functools.partial)
            and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
        )
        return (
            is_duplicate_input,
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround for a Python weakref bug.

        `obj_weakref` is an instance returned by `weakref.ref`;
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g.:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raises KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output

    def create(self, builder: GuardBuilderBase):
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.exception("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_specialized_nn_module(self):
        return self.source.is_specialized_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        if not self.guard_types:
            self.guard_types = []

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            self.code_list = code_list
        else:
            self.code_list.extend(code_list)

        # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
        # multiple guards on the same object, the weakref can die between the
        # invocations of set_export_info. So a dead weakref is also
        # acceptable.
        assert (
            self.obj_weakref in (obj_weakref, None)
            or callable(self.obj_weakref)
            and self.obj_weakref() is None
        ), "Guarded object must be identical, None or ephemeral (dead weakref)"
        self.obj_weakref = obj_weakref
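
# A minimal sketch of the create_fn contract (hypothetical helper; the real
# create_fns live on torch._dynamo.guards.GuardBuilder): a create_fn receives the
# builder scope and the Guard itself, and installs whatever check it needs.
#
#     def _my_check(builder: GuardBuilderBase, guard: Guard) -> None:
#         ...  # e.g. record a check against the object named by guard.name
#
#     guard = Guard(originating_source=some_source, create_fn=_my_check)
#     guard.create(builder)  # dispatches to _my_check(builder, guard)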


T = TypeVar("T")

"""
Parent structure for guard env expressions.
A GuardEnvExpr can have any subtype.
Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


@dataclasses.dataclass
class GuardEnvExpr:
    pass


"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the input sources we have deduped.
"""


@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
    input_source_a: Source
    input_source_b: Source

    def __post_init__(self):
        assert self.input_source_a != self.input_source_b


"""
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.

In the future, it will have a closer coupling to a generic Checkpoint management system.
"""


class Checkpointable(Generic[T]):
    @abstractmethod
    def copy_graphstate(self) -> T: ...

    @abstractmethod
    def restore_graphstate(self, state: T): ...


class GuardsCheckpointState:
    """
    The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        Guard type objects.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContextCheckpointState:
    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        module key names.
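
        For example (illustrative sketch; `linear` is a stand-in module):

            a = ModuleContextCheckpointState({"mod.linear": linear})
            b = ModuleContextCheckpointState({})
            a.diff(b)  # {"mod.linear"}
            a.diff(a)  # None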
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    def __init__(self) -> None:
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        self.nn_modules = state.nn_modules


class GlobalContextCheckpointState:
    global_state: Dict[str, Tuple[Callable, ...]] = {}

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        global key names.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        assert isinstance(state, GlobalContextCheckpointState)
        self.global_state = state.global_state
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)


"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
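
A rough usage sketch (assumes an active tracing context installed via `with tracing(...)`):

    guards = TracingContext.get().guards_context
    snapshot = guards.copy_graphstate()     # GuardsCheckpointState
    ...                                     # speculative tracing that may add guards
    guards.restore_graphstate(snapshot)     # roll back to the snapshot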
"""


# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
class GuardsSet:
    def __init__(self, inner=None):
        if inner is None:
            inner = set()
        self.inner = inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other):
        return GuardsSet(self.inner - other.inner)

    def __bool__(self):
        return bool(self.inner)

    def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
        if guard in self.inner:
            return
        if collect_debug_stack:
            if guard.stack is None:
                guard.stack = CapturedTraceback.extract(skip=1 + skip)
        if guard.user_stack is None:
            guard.user_stack = TracingContext.extract_stack()
        self.inner.add(guard)

    def update(self, *others: Set[Guard]):
        for o in others:
            for g in o:
                self.add(g, skip=1)

    def remove_guards_with_source(self, source):
        """Delete all guards with a given source"""
        self.inner = {g for g in self.inner if g.originating_source != source}


class GuardsContext(Checkpointable[GuardsCheckpointState]):
    def __init__(self) -> None:
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        return GuardsCheckpointState(set(self.dynamo_guards.inner))

    def restore_graphstate(self, state):
        # NB: "steals" the passed in state
        assert isinstance(state, GuardsCheckpointState)
        self.dynamo_guards = GuardsSet(state.dynamo_guards)


_TLS = threading.local()

"""
TracingContext is the source of truth for all currently accumulated information
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
are open to managing their own TracingContext with that in mind.

The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
having to plumb complex subsystems across multiple verticals.

A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
Accessing the current tracing context via TracingContext.get() allows users to accumulate their
own guards for processing, without needing to know how to plumb objects back up to where frame
interpretation happened.

Note that you can end up with multiple TracingContexts for a single compilation
of a frame, as we reset the TracingContext whenever we restart analysis.
CompileContext is a more overarching context that encompasses multiple restarts.
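
For example (sketch; `src_a` and `src_b` stand in for real Source objects), a
downstream component can record an extra guard expression without any plumbing:

    if (ctx := TracingContext.try_get()) is not None:
        ctx.guards_context.aotautograd_guards.append(DuplicateInputs(src_a, src_b))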
"""


class CompileContext:
    @staticmethod
    def get() -> CompileContext:
        assert _TLS.compile_context is not None
        return _TLS.compile_context

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)


class TracingContext:
    """
    Provides the currently installed TracingContext, or None.

    Note that try_get is a staticmethod, and invocations outside of `with tracing()`
    (see below) are valid but will return None.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity. As we process a frame, this variable gets updated
        # to keep track of which line of the function we are currently on. When we
        # make a function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        # this is only set after aot_autograd
        self.aot_graph_name = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known. This is always the HINT, it
        # is never a SymInt (it would be better if it were a SymInt, but
        # I can't conveniently get this from Inductor atm. Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer. This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is True, AOT Autograd will return output Fake Tensors with appropriate
        # meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False

    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        prior = {}
        ctx = TracingContext.get()

        for key in kwargs.keys():
            # KeyError on invalid entry
            prior[key] = getattr(ctx, key)
        for key, val in kwargs.items():
            setattr(ctx, key, val)
        try:
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack():
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # we set the user stack (using the context
                # manager). However, our compiler stack does "tail calls"
                # (when it calls into the user compiler), at which point the
                # parent exception frames would attach an
                # incorrect frame.
                #
                # However, if, somehow, someone raised an exception within this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise
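
    # Illustrative sketch (assumed call pattern; `user_backend`, `gm` and
    # `example_inputs` are stand-ins): code that runs outside the frame currently
    # being traced, e.g. calling into a user backend, typically wraps the call so
    # exceptions raised there do not pick up this frame's user stack:
    #
    #     with TracingContext.clear_frame():
    #         compiled_fn = user_backend(gm, example_inputs)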

    @staticmethod
    @contextlib.contextmanager
    def current_frame(frame_summary):
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides():
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name, lookup_line=False
        )


@contextmanager
def compile_context(context: Optional[CompileContext]):
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        if context is not None:
            if context.compile_id is not None:
                set_context_frame(
                    (
                        context.compile_id.frame_id,
                        context.compile_id.frame_compile_id,
                        context.attempt,
                    )
                )
        _TLS.compile_context = old_context


@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed in tracing context as a dynamically scoped
    global variable.

    Calls to TracingContext.try_get() while not under a `with tracing()` context
    will return None.
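
    A typical (sketched) pattern, where `fake_mode` stands in for a FakeTensorMode:

        with tracing(TracingContext(fake_mode)):
            ...  # trace here; TracingContext.get() now succeeds anywhere below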
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        if (
            context is not None
            and context.fake_mode is not None
            and context.fake_mode.shape_env is not None
        ):
            context.fake_mode.shape_env.cleanup()
        _TLS.tracing_context = old_context


# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
    def is_dict_key(self):
        return False

    def is_ephemeral(self):
        return False

    def reconstruct(self, codegen):
        raise NotImplementedError

    def guard_source(self) -> GuardSource:
        raise NotImplementedError

    def name(self) -> str:
        raise NotImplementedError

    def make_guard(self, fn) -> Guard:
        if self.guard_source() is GuardSource.CONSTANT:
            raise NotImplementedError
        return Guard(self, fn)

    def is_specialized_nn_module(self) -> bool:
        return self.guard_source().is_specialized_nn_module()

    def subguards_allowed(self):
        """True if you can guard on attributes of this"""
        return self.guard_source() != GuardSource.SYNTHETIC_LOCAL


# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

    def is_dict_key(self):
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self):
        return self.base.is_ephemeral()


def detect_fake_mode(inputs: Any = None):
    """
    Attempts to "detect" what the current fake mode is. If there is one ambiently
    available from TracingContext, we preferentially use that. Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

    - Currently active fake mode on stack
    - Fake mode associated with passed in tensors (inputs does not
      have to be flattened)
    """
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    fake_modes = []

    if context := TracingContext.try_get():
        fake_mode = context.fake_mode
        if fake_mode is not None:
            fake_modes.append((fake_mode, "tracing context", 0))

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            fake_modes.append((m, "active fake mode", i))

    flat_inputs = pytree.tree_leaves(inputs)
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

    if fake_modes:
        fake_mode, desc1, i1 = fake_modes[0]
        for m, desc2, i2 in fake_modes[1:]:
            assert fake_mode is m, (
                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
            )
        return fake_mode
    else:
        return None


def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.
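
    For example (sketch, using torch._subclasses.fake_tensor.FakeTensorMode):

        with FakeTensorMode() as mode:
            assert active_fake_mode() is mode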
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for m in reversed(_get_current_dispatch_mode_stack()):
        if isinstance(m, FakeTensorMode):
            return m

    return None