# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import enum
import functools
import logging
import threading
import traceback
import unittest.mock
import weakref
from abc import abstractmethod
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
)

from torch._C._dynamo.eval_frame import set_context_frame  # noqa: F401
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary


log = logging.getLogger(__name__)


if TYPE_CHECKING:
    import sympy

    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import torch


"""
torch._guards is the definitional source of truth for general purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""


class CompileId(NamedTuple):
    frame_id: int
    # This id is per-frame, and counts how many times we've compiled this
    # frame.  This could have been a global id but having this be per-frame
    # gives you a better intuitive sense for how many recompiles have occurred
    # so far.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return f"{self.frame_id}/{self.frame_compile_id}"


class TraceId(NamedTuple):
    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self):
        if self.attempt == 0:
            return str(self.compile_id)
        else:
            return f"{self.compile_id}_{self.attempt}"

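# Illustrative example (not part of the original source): how these ids render,
# assuming a frame that has been compiled three times and whose latest compile
# restarted analysis once.
#
#     cid = CompileId(frame_id=0, frame_compile_id=2)
#     str(cid)                      # "0/2"
#     str(TraceId(cid, attempt=0))  # "0/2"
#     str(TraceId(cid, attempt=1))  # "0/2_1"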

class GuardSource(enum.Enum):
    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15

    def is_fsdp_module(self) -> bool:
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        return (
            self
            in (
                GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            )
            # TODO (anijain2305) - Investigate why is_fsdp_module is required.
            or self.is_fsdp_module()
        )

    def is_unspecialized_nn_module(self) -> bool:
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_unspecialized_builtin_nn_module(self) -> bool:
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_local(self):
        return self in (
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

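# Illustrative example (not part of the original source): how the GuardSource
# predicates above classify a few members.
#
#     GuardSource.LOCAL_FSDP_MODULE.is_fsdp_module()  # True
#     GuardSource.LOCAL_FSDP_MODULE.is_local()        # True
#     GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE.is_unspecialized_nn_module()  # True
#     GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE.is_local()                    # False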

"""
Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder; it is kept to avoid a lot of renames and to preserve the original reference
to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""


class GuardBuilderBase:
    pass


class ShapeGuard(NamedTuple):
    expr: sympy.Expr
    stack: CapturedTraceback


@dataclasses.dataclass
class Guard:
    # originating_source is the source that called the make_guard method to
    # construct this guard object. The name property specifies what exactly the
    # guard is guarding on.  The meaning of the name depends on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # the name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on.  This evaluation
    # typically happens in GuardBuilder.eval.  In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the source property).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless.  Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
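    #
    # Illustrative example (not from the original comments): a guard built from
    # a hypothetical local source for a variable `x` might carry the name
    # "L['x']" (dynamo's convention for locals) together with a create_fn such
    # as GuardBuilder.TENSOR_MATCH; the builder evaluates that expression
    # against the frame's locals/globals to fetch the object to guard on.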
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at the time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    _hash: Optional[int] = None

    def __hash__(self):
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        # Put the duplicate input guards at the end. The duplicate guards have
        # two sources while guard.name only considers one source.
        from torch._dynamo.guards import GuardBuilder

        is_duplicate_input = (
            isinstance(self.create_fn, functools.partial)
            and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
        )
        return (
            is_duplicate_input,
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround for a Python weakref bug.

        `obj_weakref` is an instance returned by `weakref.ref`;
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g.:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raises KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output

    def create(self, builder: GuardBuilderBase):
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.exception("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_specialized_nn_module(self):
        return self.source.is_specialized_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        if not self.guard_types:
            self.guard_types = []

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            self.code_list = code_list
        else:
            self.code_list.extend(code_list)

        # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
        # multiple guards on the same object, the weakref can die between
        # invocations of set_export_info. So a dead weakref is also
        # acceptable.
        assert (
            self.obj_weakref in (obj_weakref, None)
            or callable(self.obj_weakref)
            and self.obj_weakref() is None
        ), "Guarded object must be identical, None or ephemeral (dead weakref)"
        self.obj_weakref = obj_weakref


T = TypeVar("T")

"""
Parent structure for guard env expressions.
A GuardEnvExpr can have any subtype.
Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


@dataclasses.dataclass
class GuardEnvExpr:
    pass


"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the two inputs we have deduplicated.
"""


@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
    input_source_a: Source
    input_source_b: Source

    def __post_init__(self):
        assert self.input_source_a != self.input_source_b


"""
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not yet provide any guarantees around consistency, idempotency, or safety of calling its APIs.

In the future, it will have a closer coupling to a generic Checkpoint management system.
"""


class Checkpointable(Generic[T]):
    @abstractmethod
    def copy_graphstate(self) -> T: ...

    @abstractmethod
    def restore_graphstate(self, state: T): ...

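# Minimal sketch (illustrative only, not part of the original source) of a
# Checkpointable implementation; the concrete contexts below all follow this
# same copy/restore pattern.
#
#     class CounterContext(Checkpointable[int]):
#         def __init__(self) -> None:
#             self.count = 0
#
#         def copy_graphstate(self) -> int:
#             return self.count
#
#         def restore_graphstate(self, state: int):
#             self.count = state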

class GuardsCheckpointState:
    """
    The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        Guard objects.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContextCheckpointState:
    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        module key names.
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    def __init__(self) -> None:
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        self.nn_modules = state.nn_modules


class GlobalContextCheckpointState:
    global_state: Dict[str, Tuple[Callable, ...]] = {}

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        global key names.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        assert isinstance(state, GlobalContextCheckpointState)
        self.global_state = state.global_state
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)

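# Illustrative example (an assumption about how callers populate GlobalContext,
# not something defined in this file): each global_state entry maps a supported
# key to a (setter, captured_value) pair so restore_graphstate can replay it.
#
#     ctx = GlobalContext()
#     ctx.global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
#     # ... one entry per key in _supported_global_states ...
#     snapshot = ctx.copy_graphstate()
#     ctx.restore_graphstate(snapshot)  # re-applies each captured value via its setter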

"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
"""


# Like a Set[Guard], but records the user stack on all guards at the
# time they are installed at their destination.
class GuardsSet:
    def __init__(self, inner=None):
        if inner is None:
            inner = set()
        self.inner = inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other):
        return GuardsSet(self.inner - other.inner)

    def __bool__(self):
        return bool(self.inner)

    def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
        if guard in self.inner:
            return
        if collect_debug_stack:
            if guard.stack is None:
                guard.stack = CapturedTraceback.extract(skip=1 + skip)
            if guard.user_stack is None:
                guard.user_stack = TracingContext.extract_stack()
        self.inner.add(guard)

    def update(self, *others: Set[Guard]):
        for o in others:
            for g in o:
                self.add(g, skip=1)

    def remove_guards_with_source(self, source):
        """Delete all guards with a given source"""
        self.inner = {g for g in self.inner if g.originating_source != source}

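# Illustrative example (not part of the original source; `guards` is a
# hypothetical GuardsSet): using subtraction and truthiness to compute the
# delta of guards added between two points, as higher order ops do.
#
#     before = GuardsSet(set(guards.inner))  # snapshot of the current guards
#     ...                                    # tracing installs more guards
#     new_guards = guards - before           # GuardsSet of just the additions
#     if new_guards:
#         ...                                # handle the delta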

class GuardsContext(Checkpointable[GuardsCheckpointState]):
    def __init__(self) -> None:
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        return GuardsCheckpointState(set(self.dynamo_guards.inner))

    def restore_graphstate(self, state):
        # NB: "steals" the passed-in state
        assert isinstance(state, GuardsCheckpointState)
        self.dynamo_guards = GuardsSet(state.dynamo_guards)


_TLS = threading.local()

"""
TracingContext is the source of truth for all currently accumulated information
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
are open to managing their own TracingContext with that in mind.

The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
having to plumb complex subsystems across multiple verticals.

Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
Accessing the current tracing context via
TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
to plumb objects back up to where frame interpretation happened.

Note that you can end up with multiple TracingContexts for a single compilation
of a frame, as we reset the TracingContext whenever we restart analysis.
CompileContext is a more overarching context that encompasses multiple restarts.
"""


class CompileContext:
    @staticmethod
    def get() -> CompileContext:
        assert _TLS.compile_context is not None
        return _TLS.compile_context

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)


class TracingContext:
    """
    Provides the currently installed TracingContext, or None.

    Note that `try_get` is a staticmethod, and invocations outside of a `with tracing()`
    block (see below) are valid but will return None.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity.  As we process a frame, this variable gets updated
        # to keep track of which line of the function we are on.  When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        # this is only set after aot_autograd
        self.aot_graph_name = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known.  This is always the HINT, it
        # is never a SymInt (it would be better if it were a SymInt, but
        # I can't conveniently get this from Inductor at the moment.  Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer.  This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is True, AOT Autograd will return output Fake Tensors with appropriate
        # meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False

    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        prior = {}
        ctx = TracingContext.get()

        for key in kwargs.keys():
            # AttributeError on invalid entry
            prior[key] = getattr(ctx, key)
        for key, val in kwargs.items():
            setattr(ctx, key, val)
        try:
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

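    # Illustrative example (not part of the original source) for the `patch`
    # context manager above; `my_metadata` is a made-up value for one of the
    # attributes defined in __init__.
    #
    #     with TracingContext.patch(fw_metadata=my_metadata):
    #         ...  # code here sees TracingContext.get().fw_metadata == my_metadata
    #     # the prior value is restored on exit
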
    @staticmethod
    def extract_stack():
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # the user stack is set (using the context
                # manager). However, our compiler stack does "tail calls"
                # (when it calls into the user compiler), at which point the
                # parent exception frames would attach an
                # incorrect frame.
                #
                # However, if, somehow, someone raised an exception within this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(frame_summary):
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides():
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name, lookup_line=False
        )

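# Illustrative sketch (not part of the original source) of how the frame-tracking
# pieces of TracingContext fit together, assuming a tracing() context is installed;
# `outer_frame_summary` and the file/line values are made up.
#
#     with TracingContext.current_frame(outer_frame_summary):
#         TracingContext.set_current_loc("example.py", 10, "inner_fn")
#         TracingContext.extract_stack()  # outer frame summaries + current location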

@contextmanager
def compile_context(context: Optional[CompileContext]):
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        if context is not None:
            if context.compile_id is not None:
                set_context_frame(
                    (
                        context.compile_id.frame_id,
                        context.compile_id.frame_compile_id,
                        context.attempt,
                    )
                )
        _TLS.compile_context = old_context


@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed-in tracing context as a dynamically scoped
    global variable.

    Calls to TracingContext.try_get() while not under a `with tracing()` context
    will return None.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        if (
            context is not None
            and context.fake_mode is not None
            and context.fake_mode.shape_env is not None
        ):
            context.fake_mode.shape_env.cleanup()
        _TLS.tracing_context = old_context

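# Illustrative sketch (not part of the original source) of how a driver might
# install both contexts around one compilation; FakeTensorMode is just one
# possible fake_mode value.
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     with compile_context(CompileContext(CompileId(0, 0))):
#         with tracing(TracingContext(FakeTensorMode())):
#             ...  # TracingContext.get() and CompileContext.get() work here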

# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
    def is_dict_key(self):
        return False

    def is_ephemeral(self):
        return False

    def reconstruct(self, codegen):
        raise NotImplementedError

    def guard_source(self) -> GuardSource:
        raise NotImplementedError

    def name(self) -> str:
        raise NotImplementedError

    def make_guard(self, fn) -> Guard:
        if self.guard_source() is GuardSource.CONSTANT:
            raise NotImplementedError
        return Guard(self, fn)

    def is_specialized_nn_module(self) -> bool:
        return self.guard_source().is_specialized_nn_module()

    def subguards_allowed(self):
        """True if you can guard on attributes of this"""
        return self.guard_source() != GuardSource.SYNTHETIC_LOCAL

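# Minimal sketch (illustrative only; real subclasses live in
# torch/_dynamo/source.py) of a concrete Source and of how make_guard ties
# Source, GuardSource, and Guard together. The class name and create_fn below
# are made up for the example.
#
#     @dataclasses.dataclass(frozen=True)
#     class MyLocalSource(Source):
#         local_name: str
#
#         def guard_source(self) -> GuardSource:
#             return GuardSource.LOCAL
#
#         def name(self) -> str:
#             return f"L[{self.local_name!r}]"
#
#     def my_create_fn(builder: GuardBuilderBase, guard: Guard) -> None:
#         ...  # emit the actual check for guard.name
#
#     guard = MyLocalSource("x").make_guard(my_create_fn)
#     guard.name    # "L['x']"
#     guard.source  # GuardSource.LOCAL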

# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

    def is_dict_key(self):
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self):
        return self.base.is_ephemeral()


def detect_fake_mode(inputs: Any = None):
    """
    Attempts to "detect" what the current fake mode is.  If there is one ambiently
    available from TracingContext, we preferentially use that.  Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

        - Currently active fake mode on stack
        - Fake mode associated with passed in tensors (inputs does not
          have to be flattened)
    """
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    fake_modes = []

    if context := TracingContext.try_get():
        fake_mode = context.fake_mode
        if fake_mode is not None:
            fake_modes.append((fake_mode, "tracing context", 0))

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            fake_modes.append((m, "active fake mode", i))

    flat_inputs = pytree.tree_leaves(inputs)
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

    if fake_modes:
        fake_mode, desc1, i1 = fake_modes[0]
        for m, desc2, i2 in fake_modes[1:]:
            assert fake_mode is m, (
                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
            )
        return fake_mode
    else:
        return None

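# Illustrative example (not part of the original source): detect_fake_mode
# recovering the mode from fake tensor inputs when no TracingContext or active
# mode is present.
#
#     import torch
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     mode = FakeTensorMode()
#     fake = mode.from_tensor(torch.randn(2, 2))
#     detect_fake_mode([fake]) is mode  # True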

def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            return m

    return None
926