xref: /aosp_15_r20/external/pytorch/torch/_guards.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport contextlib
5*da0073e9SAndroid Build Coastguard Workerimport dataclasses
6*da0073e9SAndroid Build Coastguard Workerimport enum
7*da0073e9SAndroid Build Coastguard Workerimport functools
8*da0073e9SAndroid Build Coastguard Workerimport logging
9*da0073e9SAndroid Build Coastguard Workerimport threading
10*da0073e9SAndroid Build Coastguard Workerimport traceback
11*da0073e9SAndroid Build Coastguard Workerimport unittest.mock
12*da0073e9SAndroid Build Coastguard Workerimport weakref
13*da0073e9SAndroid Build Coastguard Workerfrom abc import abstractmethod
14*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
15*da0073e9SAndroid Build Coastguard Workerfrom typing import (
16*da0073e9SAndroid Build Coastguard Worker    Any,
17*da0073e9SAndroid Build Coastguard Worker    Callable,
18*da0073e9SAndroid Build Coastguard Worker    Dict,
19*da0073e9SAndroid Build Coastguard Worker    Generic,
20*da0073e9SAndroid Build Coastguard Worker    List,
21*da0073e9SAndroid Build Coastguard Worker    NamedTuple,
22*da0073e9SAndroid Build Coastguard Worker    Optional,
23*da0073e9SAndroid Build Coastguard Worker    Set,
24*da0073e9SAndroid Build Coastguard Worker    Tuple,
25*da0073e9SAndroid Build Coastguard Worker    TYPE_CHECKING,
26*da0073e9SAndroid Build Coastguard Worker    TypeVar,
27*da0073e9SAndroid Build Coastguard Worker)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerfrom torch._C._dynamo.eval_frame import set_context_frame  # noqa: F401
30*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
31*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._traceback import CapturedTraceback
32*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.weak import WeakTensorKeyDictionary
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
# Module-level logger; Guard.create uses it to report failures while creating a guard.
log = logging.getLogger(__name__)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
39*da0073e9SAndroid Build Coastguard Worker    import sympy
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker    # Import the following modules during type checking to enable code intelligence features,
42*da0073e9SAndroid Build Coastguard Worker    # such as auto-completion in tools like pylance, even when these modules are not explicitly
43*da0073e9SAndroid Build Coastguard Worker    # imported in user code.
44*da0073e9SAndroid Build Coastguard Worker    import torch
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker"""
48*da0073e9SAndroid Build Coastguard Workertorch._guards is the definitional source of truth for general purpose guard structures.
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard WorkerAn important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
51*da0073e9SAndroid Build Coastguard Workerand no guard installation notions here.
52*da0073e9SAndroid Build Coastguard Worker"""
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
class CompileId(NamedTuple):
    """Identifies one compilation of one frame; renders as ``frame_id/frame_compile_id``."""

    frame_id: int
    # Per-frame counter of how many times this frame has been compiled.
    # A global id would also work, but keeping it per-frame makes the number
    # of recompiles so far easy to read off directly.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return "{}/{}".format(self.frame_id, self.frame_compile_id)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
class TraceId(NamedTuple):
    """
    One tracing attempt of a compilation.

    Renders as the compile id alone for the first attempt, and as
    ``<compile_id>_<attempt>`` for any later attempt.
    """

    compile_id: CompileId
    # Begins at 0 and is incremented by one each time analysis is restarted.
    attempt: int

    def __str__(self):
        rendered = str(self.compile_id)
        if self.attempt != 0:
            rendered = "{}_{}".format(rendered, self.attempt)
        return rendered
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker
class GuardSource(enum.Enum):
    """
    Where a guarded value originates.

    The ``is_*`` helpers classify members into groups: local vs. global
    origin, FSDP modules, and specialized vs. unspecialized nn modules.
    """

    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15

    def is_fsdp_module(self) -> bool:
        """True for the local and global FSDP module sources."""
        fsdp_sources = {GuardSource.LOCAL_FSDP_MODULE, GuardSource.GLOBAL_FSDP_MODULE}
        return self in fsdp_sources

    def is_specialized_nn_module(self) -> bool:
        """True for specialized nn module sources; FSDP sources also count."""
        # TODO (anijain2305) - Investigate why is_fsdp_module required.
        if self.is_fsdp_module():
            return True
        specialized_sources = {
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
        }
        return self in specialized_sources

    def is_unspecialized_nn_module(self) -> bool:
        """True for every unspecialized nn module source, builtin or not."""
        # The builtin unspecialized sources are a subset of this group.
        if self.is_unspecialized_builtin_nn_module():
            return True
        plain_unspecialized = {
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
        }
        return self in plain_unspecialized

    def is_unspecialized_builtin_nn_module(self) -> bool:
        """True only for the builtin unspecialized nn module sources."""
        builtin_sources = {
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        }
        return self in builtin_sources

    def is_local(self) -> bool:
        """True for the LOCAL* sources (SYNTHETIC_LOCAL is not included)."""
        local_sources = {
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        }
        return self in local_sources
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker"""
138*da0073e9SAndroid Build Coastguard WorkerBase class for a "GuardBuilder" role.
139*da0073e9SAndroid Build Coastguard Worker
The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder, but it is kept to avoid a lot of renames and to preserve the original
reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function. (NOTE(review): GuardSource in this file defines no select
function; this sentence may be stale.)
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard WorkerThere is value in keeping this GuardBuilderBase empty to keep layering clean.
148*da0073e9SAndroid Build Coastguard Worker"""
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker
class GuardBuilderBase:
    """Empty base type for the scope object passed as the first argument to a guard's create_fn."""
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker
class ShapeGuard(NamedTuple):
    """A ``sympy.Expr`` together with the ``CapturedTraceback`` stored alongside it."""

    # The symbolic expression being guarded.
    expr: sympy.Expr
    # Traceback captured for this guard (useful when reporting where it came from).
    stack: CapturedTraceback
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker
@dataclasses.dataclass
class Guard:
    """
    A single guard: the source being guarded plus the function that builds its check.

    ``create`` invokes ``create_fn(builder, self)``. The export-only fields are
    filled in through ``set_export_info`` and are rendered by ``__repr__`` and
    ``__str__`` for reporting.
    """

    # originating_source is the source that called the make_guard method to
    # construct this guard object. The property name specifies what exactly it
    # is the guard is guarding on.  The meaning of the name is dependent on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on.  This evaluation
    # typically happens in GuardBuilder.eval.  In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the property source).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless.  Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    # Cached result of __hash__, computed lazily on first use.
    _hash: Optional[int] = None

    def __hash__(self):
        # Hash on (name, guard source, create_fn identity). The result is cached
        # because name/source are recomputed through originating_source on
        # every access.
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        """Ordering key used by ``__lt__``."""
        # Put the duplicate input guards at the end. The duplicate guards have
        # two sources while guard.name only considers one source.
        from torch._dynamo.guards import GuardBuilder

        is_duplicate_input = (
            isinstance(self.create_fn, functools.partial)
            and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
        )
        return (
            is_duplicate_input,
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        """Return create_fn, unwrapping one level of functools.partial if present."""
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround of a Python weakref bug.

        `obj_weakref` is instance returned by `weakref.ref`,
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raise error: KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)},
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output

    def create(self, builder: GuardBuilderBase):
        """
        Run ``create_fn(builder, self)`` and return its result.

        On any exception, log this guard (and the tail of ``stack``, if one
        was recorded) and re-raise the original exception.
        """
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.exception("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_specialized_nn_module(self):
        return self.source.is_specialized_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        """
        Accumulate export information for this guard.

        ``guard_type`` is appended to ``guard_types`` and ``code_list`` is
        appended to this guard's ``code_list``. ``guarded_class`` and
        ``obj_weakref`` must agree with any previously recorded values.
        """
        if not self.guard_types:
            self.guard_types = []

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            # Store a private copy: keeping the caller's list would let a later
            # extend() below silently mutate a list the caller still owns.
            self.code_list = None if code_list is None else list(code_list)
        else:
            self.code_list.extend(code_list)

        # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
        # multiple guards on the same object, the weakref can die between the
        # invocation of set_export_info calls. So a dead weakref is also
        # acceptable.
        assert self.obj_weakref in (obj_weakref, None) or (
            callable(self.obj_weakref) and self.obj_weakref() is None
        ), "Guarded object must be identical, None or ephemeral (dead weakref)"
        self.obj_weakref = obj_weakref
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker
# Type variable for the snapshot type produced and consumed by Checkpointable[T].
T = TypeVar("T")
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker"""
329*da0073e9SAndroid Build Coastguard WorkerParent structure for guard env expressions.
330*da0073e9SAndroid Build Coastguard WorkerA GuardEnvExpr can have any subtype.
331*da0073e9SAndroid Build Coastguard WorkerNote: All subtypes must be handled exhaustively in
332*da0073e9SAndroid Build Coastguard Workertorch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
333*da0073e9SAndroid Build Coastguard Worker"""
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker
@dataclasses.dataclass
class GuardEnvExpr:
    """Field-less base dataclass for guard env expressions; concrete kinds subclass it."""
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the two inputs we have deduped; they must differ.
344*da0073e9SAndroid Build Coastguard Worker"""
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker
@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
    """Records two input sources that were deduplicated against each other."""

    # The two inputs' sources; __post_init__ requires them to compare unequal.
    input_source_a: Source
    input_source_b: Source

    def __post_init__(self):
        lhs, rhs = self.input_source_a, self.input_source_b
        assert lhs != rhs
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker"""
357*da0073e9SAndroid Build Coastguard WorkerCheckpointable is an interface for driving state snapshotting, left purposely vague for now.
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Workercopy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
360*da0073e9SAndroid Build Coastguard Workercan also be taken in at restore_graphstate(T) calls.
361*da0073e9SAndroid Build Coastguard Worker
When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard WorkerIn the future, it will have a closer coupling to a generic Checkpoint management system.
366*da0073e9SAndroid Build Coastguard Worker"""
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker
class Checkpointable(Generic[T]):
    """
    Interface for objects whose state can be snapshotted and restored.

    Note: this class does not use ABCMeta, so @abstractmethod here is
    documentation only and is not enforced when subclasses are instantiated.
    """

    @abstractmethod
    def copy_graphstate(self) -> T:
        """Return a snapshot of the current state that restore_graphstate accepts."""

    @abstractmethod
    def restore_graphstate(self, state: T):
        """Reset the current state to a snapshot produced by copy_graphstate."""
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker
class GuardsCheckpointState:
    """
    The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext

    Two states compare equal when neither holds a guard the other lacks.
    """

    # Class-level annotation/default; __init__ always assigns a per-instance value.
    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        Guard type objects.

        The comparison is one-directional: only guards present in ``self`` but
        missing from ``other`` are reported.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        # diff() is one-directional, so check both ways; otherwise a state
        # holding a strict subset of another's guards would compare equal to it
        # (and equality would not be symmetric).
        if not isinstance(other, GuardsCheckpointState):
            return NotImplemented
        return self.diff(other) is None and other.diff(self) is None
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker
class ModuleContextCheckpointState:
    """
    Snapshot of a ModuleContext's ``nn_modules`` mapping (the T of
    Checkpointable[T] for ModuleContext).

    Equality compares only the module names (keys), in both directions.
    """

    # Class-level annotation/default; __init__ always assigns a per-instance value.
    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        module key names.

        Only keys are compared, and only one way: the result holds names present
        in ``self`` but absent from ``other``.
        """
        r = set(self.nn_modules).difference(other.nn_modules)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        # diff() is one-directional, so check both ways to keep equality
        # symmetric (a subset of module names must not compare equal).
        if not isinstance(other, ModuleContextCheckpointState):
            return NotImplemented
        return self.diff(other) is None and other.diff(self) is None
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    """Holds a name -> module mapping that can be snapshotted and restored."""

    def __init__(self) -> None:
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        # Shallow copy so later changes to self.nn_modules do not leak into the snapshot.
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        # Copy on restore as well: aliasing state.nn_modules would let later
        # additions to self.nn_modules mutate the checkpoint, so restoring the
        # same checkpoint a second time would not yield the snapshotted contents.
        self.nn_modules = dict(state.nn_modules)
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker
class GlobalContextCheckpointState:
    """
    Snapshot of a GlobalContext's ``global_state`` mapping, as produced by
    GlobalContext.copy_graphstate().
    """

    # Declared for type checkers only. The value is always assigned per
    # instance in __init__; a class-level ``{}`` default would be one dict
    # shared by every instance.
    global_state: Dict[str, Tuple[Callable, ...]]

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found. Otherwise returns a set() of global
        key names that are present in ``self`` but missing from ``other``.
        The check is one-directional and compares key names only, not the
        stored values.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        """
        Two states are equal when they hold the same set of global key names.

        ``diff`` is one-directional, so it is checked in both directions to
        keep ``==`` symmetric. Comparing against any other type returns
        NotImplemented, so Python falls back to its default handling instead
        of raising AttributeError.
        """
        if not isinstance(other, GlobalContextCheckpointState):
            return NotImplemented
        return self.diff(other) is None and other.diff(self) is None
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker
class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.

    Each entry of ``global_state`` unpacks into a ``(func, args)`` pair.
    Restoring a checkpoint calls ``func(args)`` for every entry.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        """Return a checkpoint holding a shallow copy of ``global_state``."""
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        """
        Install ``state`` as the current global state and re-apply each entry.

        The checkpoint's key set is validated before anything is assigned, so
        a mismatched checkpoint leaves this context untouched. A shallow copy
        of the checkpoint's dict is installed, so mutating
        ``self.global_state`` in place afterwards does not alter ``state``.
        """
        assert isinstance(state, GlobalContextCheckpointState)
        # Dict keys are unique, so key-set equality also implies the entry
        # count matches the supported set.
        assert (
            set(state.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        self.global_state = dict(state.global_state)
        for func, args in self.global_state.values():
            func(args)
489*da0073e9SAndroid Build Coastguard Worker            func(args)
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
497*da0073e9SAndroid Build Coastguard Worker"""
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker
# A Set[Guard]-like container. On insertion it can record on each guard the
# stacks that were current when the guard was installed (see ``add``).
class GuardsSet:
    def __init__(self, inner=None):
        # A caller-supplied ``inner`` is adopted as-is, without copying.
        self.inner = set() if inner is None else inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)

    # ``new - old`` combined with truthiness is the usual way to find which
    # guards were added between two checkpoints (e.g. for higher order ops).
    def __sub__(self, other):
        remaining = self.inner - other.inner
        return GuardsSet(remaining)

    def __bool__(self):
        return bool(self.inner)

    def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
        """
        Insert ``guard`` if it is not already a member.

        With ``collect_debug_stack`` enabled, a guard whose ``stack`` or
        ``user_stack`` is still None has it filled in first. Existing values
        are never overwritten. ``skip`` plus one is forwarded to
        CapturedTraceback.extract (presumably to hide caller frames; the
        extra one is this frame).
        """
        if guard not in self.inner:
            if collect_debug_stack:
                if guard.stack is None:
                    guard.stack = CapturedTraceback.extract(skip=skip + 1)
                if guard.user_stack is None:
                    guard.user_stack = TracingContext.extract_stack()
            self.inner.add(guard)

    def update(self, *others: Set[Guard]):
        for guard_collection in others:
            for member in guard_collection:
                # skip=1 accounts for this extra update() frame.
                self.add(member, skip=1)

    def remove_guards_with_source(self, source):
        """Drop every guard whose ``originating_source`` matches ``source``.

        ``inner`` is rebound to a new set; the previous set object is left
        unmodified.
        """
        self.inner = set(filter(lambda g: g.originating_source != source, self.inner))
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker
class GuardsContext(Checkpointable[GuardsCheckpointState]):
    def __init__(self) -> None:
        # The guard set that copy_graphstate/restore_graphstate operate on.
        self.dynamo_guards: GuardsSet = GuardsSet()
        # Not captured by copy_graphstate nor touched by restore_graphstate.
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        """Snapshot ``dynamo_guards`` into a GuardsCheckpointState (set copy)."""
        snapshot = set(self.dynamo_guards.inner)
        return GuardsCheckpointState(snapshot)

    def restore_graphstate(self, state):
        """
        Replace ``dynamo_guards`` with the guards held by ``state``.

        NB: "steals" the passed in state. Its set becomes the backing set of
        the new GuardsSet without being copied.
        """
        assert isinstance(state, GuardsCheckpointState)
        stolen = state.dynamo_guards
        self.dynamo_guards = GuardsSet(stolen)
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker
# Per-thread storage for the active contexts: ``compile_context`` (installed by
# compile_context()) and ``tracing_context`` (installed by tracing()). Neither
# attribute exists on a thread until first assigned, which is why readers use
# getattr(_TLS, ..., None).
_TLS = threading.local()
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker"""
559*da0073e9SAndroid Build Coastguard WorkerTracingContext is the source of truth for all currently accumulated information
560*da0073e9SAndroid Build Coastguard Workerneeded to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
561*da0073e9SAndroid Build Coastguard Workerare open to managing their own TracingContext with that in mind.
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard WorkerThe purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
564*da0073e9SAndroid Build Coastguard Workerhaving to plumb complex subsystems across multiple verticals.
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard WorkerEx: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
567*da0073e9SAndroid Build Coastguard WorkerAccessing the current tracing context via
568*da0073e9SAndroid Build Coastguard WorkerTracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
569*da0073e9SAndroid Build Coastguard Workerto plumb objects back up to where frame interpretation happened.
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard WorkerNote that you can end up with multiple TracingContext for a single compilation
572*da0073e9SAndroid Build Coastguard Workerof a frame, as we reset the TracingContext whenever we restart analysis.
573*da0073e9SAndroid Build Coastguard WorkerCompileContext is a more overarching context that encompasses multiple restarts.
574*da0073e9SAndroid Build Coastguard Worker"""
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker
class CompileContext:
    """
    Per-compilation state installed on the current thread by
    ``compile_context()``. It outlives individual TracingContexts, covering
    the restarts of analysis for one compilation.
    """

    @staticmethod
    def get() -> CompileContext:
        """
        Return the CompileContext installed on this thread.

        Raises AssertionError if none is installed.
        """
        # Go through try_get() so that a thread which never entered
        # compile_context() hits the assert rather than an AttributeError
        # from the thread-local lookup.
        ctx = CompileContext.try_get()
        assert ctx is not None
        return ctx

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        """Return the CompileContext installed on this thread, or None."""
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        # Combined with compile_id to form the TraceId (see current_trace_id).
        # Presumably bumped elsewhere on each restart; not modified here.
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        """Return the active CompileContext's compile_id, or None if there is none."""
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        """
        Return a TraceId(compile_id, attempt) for the active CompileContext.

        Returns None when no CompileContext is active or it has no compile_id.
        """
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker
class TracingContext:
    """
    Holds the state accumulated while tracing: the guards, module and global
    contexts, the fake mode, the user frame stack, and the per-trace settings
    initialized in __init__.

    The active instance is stored per thread and installed with
    ``with tracing(ctx)``. Retrieve it with the staticmethods ``try_get()``,
    which returns None outside of ``tracing()``, or ``get()``, which raises
    RuntimeError instead.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        """
        Provides the currently installed TracingContext, or None.

        Invocations outside of `with tracing()` (see below) are valid but
        will return None.
        """
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        """Return the installed TracingContext; raise RuntimeError if there is none."""
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity.  As we process a frame, this variable gets updated
        # to keep track of what line we are in the function.  When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        # this is only set after aot_autograd
        self.aot_graph_name = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known.  This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm.  Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer.  This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is True, Aot Autograd will return output Fake Tensors with
        # appropriate meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False

    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        """
        Temporarily set attributes on the current TracingContext.

        Every key must name an existing attribute. An unknown key raises
        AttributeError before anything has been changed. All prior values are
        restored on exit, including when an assignment fails part-way through.
        """
        prior = {}
        ctx = TracingContext.get()

        for key in kwargs.keys():
            # AttributeError on invalid entry (nothing has been modified yet)
            prior[key] = getattr(ctx, key)
        try:
            # Assign inside the try so a failing setattr still triggers the
            # restore of every key that may already have been changed.
            for key, val in kwargs.items():
                setattr(ctx, key, val)
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack():
        """
        Return the user stack as a traceback.StackSummary.

        The summary is frame_summary_stack, followed by loc_in_frame when it
        is set. It is empty when no TracingContext is installed.
        """
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # the user stack is set (using the context manager). However,
                # our compiler stack does "tail calls" (when it calls into user
                # compiler), at which point the parent exception frames would
                # incorrectly attach an incorrect frame.
                #
                # However, if, somehow, someone raised an exception with this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(frame_summary):
        """
        Push ``frame_summary`` onto the frame stack for the duration of the block.

        ``loc_in_frame`` is reset to None inside the block and restored
        afterwards. An exception escaping the block without a ``real_stack``
        attribute gets the current extracted stack attached.
        """
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides():
        """
        Yield a fresh list installed as ``output_strides`` for the block.

        Yields None when no TracingContext is active. The previous
        ``output_strides`` value is restored on exit.
        """
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        """Record the location being processed as ``loc_in_frame`` (source line not read)."""
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name, lookup_line=False
        )
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker
@contextmanager
def compile_context(context: Optional[CompileContext]):
    """
    Install ``context`` as this thread's CompileContext for the ``with`` block.

    On exit, if ``context`` has a compile_id, its
    (frame_id, frame_compile_id, attempt) triple is passed to
    ``set_context_frame``. The previously installed CompileContext is then
    restored, even if that call raises.
    """
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        # Nested try/finally: a failure in set_context_frame must not leave
        # this thread pointing at a context whose scope has ended.
        try:
            if context is not None and context.compile_id is not None:
                set_context_frame(
                    (
                        context.compile_id.frame_id,
                        context.compile_id.frame_compile_id,
                        context.attempt,
                    )
                )
        finally:
            _TLS.compile_context = old_context
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker
@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed in tracing context as a dynamic scoped
    global variable.

    While not under a `with tracing()` context, TracingContext.try_get()
    returns None and TracingContext.get() raises RuntimeError.

    An exception escaping the block gets a ``real_stack`` attribute (the
    context's extracted user stack) if it does not already have one. On
    exit, the fake mode's shape_env is cleaned up when present. The previous
    tracing context is then restored, even if that cleanup raises.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        # Nested try/finally: a failing cleanup must not leave this thread
        # pointing at a context whose scope has ended.
        try:
            if (
                context is not None
                and context.fake_mode is not None
                and context.fake_mode.shape_env is not None
            ):
                context.fake_mode.shape_env.cleanup()
        finally:
            _TLS.tracing_context = old_context
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker# Subclasses can be found in torch/_dynamo/source.py
825*da0073e9SAndroid Build Coastguard Worker# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
    """
    Frozen base class for guard sources; concrete subclasses live elsewhere
    (see the note above this class).

    The base implementation reports "not a dict key" and "not ephemeral",
    and leaves ``reconstruct``, ``guard_source`` and ``name`` unimplemented
    (they raise NotImplementedError).
    """

    def is_dict_key(self):
        # A bare Source is never a dictionary key.
        return False

    def is_ephemeral(self):
        # A bare Source is never ephemeral.
        return False

    def reconstruct(self, codegen):
        # Subclasses must emit code that rebuilds the value via ``codegen``.
        raise NotImplementedError

    def guard_source(self) -> GuardSource:
        # Subclasses must report which GuardSource category they belong to.
        raise NotImplementedError

    def name(self) -> str:
        # Subclasses must provide a printable name.
        raise NotImplementedError

    def make_guard(self, fn) -> Guard:
        """Return ``Guard(self, fn)``; CONSTANT sources raise NotImplementedError."""
        origin = self.guard_source()
        if origin is not GuardSource.CONSTANT:
            return Guard(self, fn)
        raise NotImplementedError

    def is_specialized_nn_module(self) -> bool:
        """Delegate to this source's GuardSource category."""
        origin = self.guard_source()
        return origin.is_specialized_nn_module()

    def subguards_allowed(self):
        """True if you can guard on attributes of this"""
        origin = self.guard_source()
        return origin != GuardSource.SYNTHETIC_LOCAL
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    """
    A Source that wraps another Source stored in ``base``.

    Dict-key-ness and ephemerality are answered by asking ``base``.
    """

    base: Source

    def is_dict_key(self):
        # Walk down the chain of bases; the walk stops at the first source
        # that does not delegate (e.g. a ConstDictKey or a plain Source).
        inner = self.base
        return inner.is_dict_key()

    def is_ephemeral(self):
        # Ephemerality is inherited from the wrapped source.
        inner = self.base
        return inner.is_ephemeral()
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Workerdef detect_fake_mode(inputs: Any = None):
870*da0073e9SAndroid Build Coastguard Worker    """
871*da0073e9SAndroid Build Coastguard Worker    Attempts to "detect" what the current fake mode is.  If there is one ambiently
872*da0073e9SAndroid Build Coastguard Worker    available from TracingContext, we preferentially use that.  Otherwise, we
873*da0073e9SAndroid Build Coastguard Worker    heuristically detect the fake mode via the following sources, in order of
874*da0073e9SAndroid Build Coastguard Worker    priority:
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker        - Currently active fake mode on stack
877*da0073e9SAndroid Build Coastguard Worker        - Fake mode associated with passed in tensors (inputs does not
878*da0073e9SAndroid Build Coastguard Worker          have to be flattened)
879*da0073e9SAndroid Build Coastguard Worker    """
880*da0073e9SAndroid Build Coastguard Worker    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker    fake_modes = []
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker    if context := TracingContext.try_get():
885*da0073e9SAndroid Build Coastguard Worker        fake_mode = context.fake_mode
886*da0073e9SAndroid Build Coastguard Worker        if fake_mode is not None:
887*da0073e9SAndroid Build Coastguard Worker            fake_modes.append((fake_mode, "tracing context", 0))
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
892*da0073e9SAndroid Build Coastguard Worker        if isinstance(m, FakeTensorMode):
893*da0073e9SAndroid Build Coastguard Worker            fake_modes.append((m, "active fake mode", i))
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker    flat_inputs = pytree.tree_leaves(inputs)
896*da0073e9SAndroid Build Coastguard Worker    for i, flat_input in enumerate(flat_inputs):
897*da0073e9SAndroid Build Coastguard Worker        if isinstance(flat_input, FakeTensor):
898*da0073e9SAndroid Build Coastguard Worker            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker    if fake_modes:
901*da0073e9SAndroid Build Coastguard Worker        fake_mode, desc1, i1 = fake_modes[0]
902*da0073e9SAndroid Build Coastguard Worker        for m, desc2, i2 in fake_modes[1:]:
903*da0073e9SAndroid Build Coastguard Worker            assert fake_mode is m, (
904*da0073e9SAndroid Build Coastguard Worker                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
905*da0073e9SAndroid Build Coastguard Worker                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
906*da0073e9SAndroid Build Coastguard Worker                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
907*da0073e9SAndroid Build Coastguard Worker            )
908*da0073e9SAndroid Build Coastguard Worker        return fake_mode
909*da0073e9SAndroid Build Coastguard Worker    else:
910*da0073e9SAndroid Build Coastguard Worker        return None
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker
def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.

    The stack is scanned from its end toward its start (presumably innermost
    mode first), and the first FakeTensorMode found is returned.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    # Iterate the modes directly; the previous enumerate() index was unused.
    for mode in reversed(_get_current_dispatch_mode_stack()):
        if isinstance(mode, FakeTensorMode):
            return mode

    return None
926