1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport contextlib 5*da0073e9SAndroid Build Coastguard Workerimport dataclasses 6*da0073e9SAndroid Build Coastguard Workerimport enum 7*da0073e9SAndroid Build Coastguard Workerimport functools 8*da0073e9SAndroid Build Coastguard Workerimport logging 9*da0073e9SAndroid Build Coastguard Workerimport threading 10*da0073e9SAndroid Build Coastguard Workerimport traceback 11*da0073e9SAndroid Build Coastguard Workerimport unittest.mock 12*da0073e9SAndroid Build Coastguard Workerimport weakref 13*da0073e9SAndroid Build Coastguard Workerfrom abc import abstractmethod 14*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager 15*da0073e9SAndroid Build Coastguard Workerfrom typing import ( 16*da0073e9SAndroid Build Coastguard Worker Any, 17*da0073e9SAndroid Build Coastguard Worker Callable, 18*da0073e9SAndroid Build Coastguard Worker Dict, 19*da0073e9SAndroid Build Coastguard Worker Generic, 20*da0073e9SAndroid Build Coastguard Worker List, 21*da0073e9SAndroid Build Coastguard Worker NamedTuple, 22*da0073e9SAndroid Build Coastguard Worker Optional, 23*da0073e9SAndroid Build Coastguard Worker Set, 24*da0073e9SAndroid Build Coastguard Worker Tuple, 25*da0073e9SAndroid Build Coastguard Worker TYPE_CHECKING, 26*da0073e9SAndroid Build Coastguard Worker TypeVar, 27*da0073e9SAndroid Build Coastguard Worker) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workerfrom torch._C._dynamo.eval_frame import set_context_frame # noqa: F401 30*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 31*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._traceback import CapturedTraceback 32*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.weak import WeakTensorKeyDictionary 
log = logging.getLogger(__name__)


if TYPE_CHECKING:
    import sympy

    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import torch


"""
torch._guards is the definitional source of truth for general purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""


class CompileId(NamedTuple):
    """Identifies one compilation of one frame; renders as "frame_id/frame_compile_id"."""

    frame_id: int
    # This id is per-frame, and counts how many times we've compiled this
    # frame. This could have been a global id but having this be per-frame
    # gives you a better intuitive sense for how many recompiles have occurred
    # so far.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return f"{self.frame_id}/{self.frame_compile_id}"


class TraceId(NamedTuple):
    """A CompileId plus the analysis attempt number.

    Renders as the bare CompileId for the first attempt and as
    "<compile_id>_<attempt>" for later attempts.
    """

    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self):
        if self.attempt == 0:
            return str(self.compile_id)
        else:
            return f"{self.compile_id}_{self.attempt}"


class GuardSource(enum.Enum):
    """Where the object a guard inspects comes from (locals, globals, modules, ...)."""

    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15

    def is_fsdp_module(self) -> bool:
        """True for the local and global FSDP module sources."""
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        """True for specialized nn.Module sources; FSDP module sources also count."""
        return (
            self
            in (
                GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            )
            # TODO (anijain2305) - Investigate why is_fsdp_module required.
            or self.is_fsdp_module()
        )

    def is_unspecialized_nn_module(self) -> bool:
        """True for every unspecialized nn.Module source, builtin or not."""
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_unspecialized_builtin_nn_module(self) -> bool:
        """True only for the builtin flavours of unspecialized nn.Module sources."""
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_local(self) -> bool:
        """True for every LOCAL_* source (plain, specialized, FSDP, unspecialized)."""
        return self in (
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )


"""
Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder, but it is kept to avoid a lot of renames and to preserve the original
reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. The concrete builder (e.g.
torch._dynamo.guards.GuardBuilder) is supplied by whoever calls Guard.create.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""


class GuardBuilderBase:
    pass


class ShapeGuard(NamedTuple):
    """A symbolic shape expression together with the stack that produced it."""

    expr: sympy.Expr
    stack: CapturedTraceback


@dataclasses.dataclass
class Guard:
    """A single guard: the source being guarded plus the create_fn that builds its check."""

    # originating_source is the source that called the make_guard method to
    # construct this guard object. The property name specifies what exactly it
    # is the guard is guarding on. The meaning of the name is dependent on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on. This evaluation
    # typically happens in GuardBuilder.eval. In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the property source).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless. Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    # Lazily computed and cached by __hash__.
    _hash: Optional[int] = None

    def __hash__(self):
        # Identity of a guard is (name, source, create_fn object); computed once.
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        """Key used by __lt__ so guards sort deterministically."""
        # Put the duplicate input guards at the end. The duplicate guards have
        # two sources while guard.name only considers one source.
        from torch._dynamo.guards import GuardBuilder

        is_duplicate_input = (
            isinstance(self.create_fn, functools.partial)
            and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
        )
        return (
            is_duplicate_input,
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        """Return create_fn, unwrapping one level of functools.partial."""
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround of a Python weakref bug.

        `obj_weakref` is instance returned by `weakref.ref`,
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raise error: KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output

    def create(self, builder: GuardBuilderBase):
        """Invoke create_fn(builder, self); on failure log the guard and re-raise."""
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.exception("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_specialized_nn_module(self):
        return self.source.is_specialized_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        """Record export metadata produced while building this guard's check_fn.

        May be called several times for the same guard: guard types and code
        lines accumulate across calls, while the guarded class and guarded
        object must stay consistent (or be None / a dead weakref).
        """
        if not self.guard_types:
            self.guard_types = []

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            # Copy so the extend() on later calls never mutates the list
            # object the caller passed in on the first call.
            self.code_list = list(code_list) if code_list is not None else None
        else:
            self.code_list.extend(code_list)

        # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
        # multiple guards on the same object, the weakref can die between the
        # invocation of set_export_info calls. So a dead weakref is also
        # acceptable.
        assert (
            self.obj_weakref in (obj_weakref, None)
            or callable(self.obj_weakref)
            and self.obj_weakref() is None
        ), "Guarded object must be identical, None or ephemeral (dead weakref)"
        self.obj_weakref = obj_weakref


T = TypeVar("T")

"""
Parent structure for guard env expressions.
A GuardEnvExpr can have any subtype.
Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


@dataclasses.dataclass
class GuardEnvExpr:
    pass


"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the two inputs we have deduped.
"""


@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
    input_source_a: Source
    input_source_b: Source

    def __post_init__(self):
        # A pair of identical sources is not a duplicate; it is the same input.
        assert self.input_source_a != self.input_source_b


"""
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.

In the future, it will have a closer coupling to a generic Checkpoint management system.
"""


class Checkpointable(Generic[T]):
    # NOTE: this class does not use ABCMeta, so @abstractmethod is documentary
    # only; subclasses that omit these methods can still be instantiated.
    @abstractmethod
    def copy_graphstate(self) -> T: ...

    @abstractmethod
    def restore_graphstate(self, state: T): ...
class GuardsCheckpointState:
    """
    The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        Guard type objects.

        The delta is one-directional: only guards present in self but absent from
        other are reported.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        # Because diff() is one-directional, this is True whenever self's guards
        # are a subset of other's.
        return self.diff(other) is None


class ModuleContextCheckpointState:
    """Snapshot of ModuleContext.nn_modules (module key name -> module)."""

    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        module key names.

        Only keys are compared, and only keys present in self but absent from
        other are reported.
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    """Checkpointable registry of the nn.Modules seen while tracing."""

    def __init__(self) -> None:
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        # Copy, mirroring copy_graphstate(): otherwise later writes to
        # self.nn_modules would silently rewrite the checkpoint itself.
        self.nn_modules = dict(state.nn_modules)


class GlobalContextCheckpointState:
    """Snapshot of GlobalContext.global_state (state name -> (setter, value))."""

    global_state: Dict[str, Tuple[Callable, ...]] = {}

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        global key names.

        Only keys are compared, and only keys present in self but absent from
        other are reported.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        """Adopt a copy of the snapshot, then re-apply every recorded global setting.

        Each entry is a (func, args) pair and is replayed as func(args).
        """
        assert isinstance(state, GlobalContextCheckpointState)
        # Copy, mirroring copy_graphstate(): otherwise later writes to
        # self.global_state would silently rewrite the checkpoint itself.
        self.global_state = dict(state.global_state)
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)
"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
"""


# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
class GuardsSet:
    def __init__(self, inner=None):
        # Allocate a fresh set per instance rather than sharing a default.
        if inner is None:
            inner = set()
        self.inner = inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)

    def __contains__(self, guard):
        # Delegate to the backing set so `guard in guards_set` is a hashed
        # lookup instead of a linear scan through __iter__.
        return guard in self.inner

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other):
        return GuardsSet(self.inner - other.inner)

    def __bool__(self):
        return bool(self.inner)

    def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
        """
        Add ``guard`` unless it is already present.

        When ``collect_debug_stack`` is true, a missing ``guard.stack`` is
        filled from CapturedTraceback.extract (called with ``skip=1 + skip``)
        and a missing ``guard.user_stack`` from TracingContext.extract_stack().
        Stacks that are already set are left untouched.
        """
        if guard in self.inner:
            return
        if collect_debug_stack:
            if guard.stack is None:
                guard.stack = CapturedTraceback.extract(skip=1 + skip)
            if guard.user_stack is None:
                guard.user_stack = TracingContext.extract_stack()
        self.inner.add(guard)

    def update(self, *others: Set[Guard]):
        """Add every guard contained in each iterable of ``others``."""
        for o in others:
            for g in o:
                # skip=1 presumably hides this update() frame from the
                # captured internal stack.
                self.add(g, skip=1)

    def remove_guards_with_source(self, source):
        """Delete all guards with a given source"""
        self.inner = {g for g in self.inner if g.originating_source != source}


class GuardsContext(Checkpointable[GuardsCheckpointState]):
    """Checkpointable holder of the guards accumulated while tracing."""

    def __init__(self) -> None:
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        # Only dynamo_guards is checkpointed (aotautograd_guards is not); the
        # set is copied so later additions do not leak into the snapshot.
        return GuardsCheckpointState(set(self.dynamo_guards.inner))

    def restore_graphstate(self, state):
        # NB: "steals" the passed in state
        assert isinstance(state, GuardsCheckpointState)
        self.dynamo_guards = GuardsSet(state.dynamo_guards)


_TLS = threading.local()

"""
TracingContext is the source of truth for all currently accumulated information
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
are open to managing their own TracingContext with that in mind.

The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
having to plumb complex subsystems across multiple verticals.

Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
Accessing the current tracing context via
TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
to plumb objects back up to where frame interpretation happened.

Note that you can end up with multiple TracingContext for a single compilation
of a frame, as we reset the TracingContext whenever we restart analysis.
CompileContext is a more overarching context that encompasses multiple restarts.
"""


class CompileContext:
    """
    Compilation-wide state installed via ``compile_context()``.

    ``attempt`` starts at 0; it is presumably advanced by callers on each
    restart of analysis (it is combined with compile_id into a TraceId).
    """

    @staticmethod
    def get() -> CompileContext:
        """Return the installed CompileContext; assert that one is installed."""
        # Go through try_get() so a thread that never entered compile_context()
        # trips the assertion instead of raising AttributeError on _TLS.
        ctx = CompileContext.try_get()
        assert ctx is not None, "CompileContext.get() called outside compile_context()"
        return ctx

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        """Return the installed CompileContext, or None if there is none."""
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        """Return the active compile_id, or None outside a compile context."""
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        """Return TraceId(compile_id, attempt), or None if unavailable."""
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)


class TracingContext:
    """
    Per-trace state: guards, module/global contexts, the fake mode, the user
    frame stack, and metadata exchanged with aot_autograd.

    Use TracingContext.try_get() to obtain the installed context; it returns
    None outside of `with tracing()` (see below). TracingContext.get() raises
    RuntimeError instead of returning None.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        """Return the installed TracingContext, or None if there is none."""
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        """Return the installed TracingContext; raise RuntimeError if absent."""
        ctx = TracingContext.try_get()
        if ctx is not None:
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity. As we process a frame, this variable gets updated
        # to keep track of what line we are in the function. When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        # this is only set after aot_autograd
        self.aot_graph_name = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known. This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm. Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer. This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is true, Aot Autograd will return output Fake Tensors with
        # appropriate meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False

    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        """
        Temporarily set attributes of the current TracingContext.

        Each keyword must name an existing attribute; all prior values are
        restored on exit. Raises RuntimeError outside of a trace and
        AttributeError for an unknown attribute name.
        """
        ctx = TracingContext.get()
        # Read every prior value before modifying anything, so an invalid
        # entry (AttributeError) leaves the context untouched.
        prior = {key: getattr(ctx, key) for key in kwargs}
        try:
            # Setting inside the try guarantees rollback even if a later
            # setattr fails part-way through.
            for key, val in kwargs.items():
                setattr(ctx, key, val)
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack():
        """
        Return the user stack as a traceback.StackSummary: frame_summary_stack
        plus loc_in_frame (when set). Empty when no TracingContext is installed.
        """
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # we set the user stack (using the context manager). However,
                # our compiler stack does "tail calls" (when it calls into the
                # user compiler), at which point the parent exception frames
                # would attach an incorrect frame.
                #
                # However, if, somehow, someone raised an exception within this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(frame_summary):
        """
        Push ``frame_summary`` onto the user frame stack for the duration of
        the block, resetting loc_in_frame; both are restored on exit. An
        escaping exception without ``real_stack`` gets the current stack.
        """
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides():
        """
        Yield a fresh list installed as ``output_strides`` for the block (or
        None when no TracingContext is installed); the previous value is
        restored on exit.
        """
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        """Record the current source location within the innermost frame."""
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name, lookup_line=False
        )


@contextmanager
def compile_context(context: Optional[CompileContext]):
    """
    Install ``context`` as the thread-local CompileContext for the duration of
    the ``with`` block and restore the previous one afterwards.

    On exit, if ``context`` has a compile_id, the tuple
    (frame_id, frame_compile_id, attempt) is passed to set_context_frame.
    """
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        try:
            if context is not None and context.compile_id is not None:
                set_context_frame(
                    (
                        context.compile_id.frame_id,
                        context.compile_id.frame_compile_id,
                        context.attempt,
                    )
                )
        finally:
            # Restore the previous context even if the call above raises.
            _TLS.compile_context = old_context


@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed in tracing context as a dynamic scoped
    global variable.

    While not under a `with tracing()` context (or under `with tracing(None)`),
    TracingContext.try_get() returns None and TracingContext.get() raises
    RuntimeError.

    An exception escaping the block without ``real_stack`` gets the current
    user stack attached. On exit, the fake mode's shape_env (if any) is
    cleaned up and the previous tracing context is restored.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        try:
            if (
                context is not None
                and context.fake_mode is not None
                and context.fake_mode.shape_env is not None
            ):
                context.fake_mode.shape_env.cleanup()
        finally:
            # Restore the previous context even if cleanup() raises.
            _TLS.tracing_context = old_context


# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
    """Base class describing where a traced value came from."""

    def is_dict_key(self):
        return False

    def is_ephemeral(self):
        return False

    def reconstruct(self, codegen):
        raise NotImplementedError

    def guard_source(self) -> GuardSource:
        raise NotImplementedError

    def name(self) -> str:
        raise NotImplementedError

    def make_guard(self, fn) -> Guard:
        """Build a Guard on this source; not supported for CONSTANT sources."""
        if self.guard_source() is GuardSource.CONSTANT:
            raise NotImplementedError
        return Guard(self, fn)

    def is_specialized_nn_module(self) -> bool:
        return self.guard_source().is_specialized_nn_module()

    def subguards_allowed(self):
        """True if you can guard on attributes of this"""
        return self.guard_source() != GuardSource.SYNTHETIC_LOCAL


# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    """A Source derived from another ``base`` Source."""

    base: Source

    def is_dict_key(self):
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self):
        return self.base.is_ephemeral()


def detect_fake_mode(inputs: Any = None):
    """
    Attempts to "detect" what the current fake mode is. If there is one ambiently
    available from TracingContext, we preferentially use that. Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

        - Currently active fake mode on stack
        - Fake mode associated with passed in tensors (inputs does not
          have to be flattened)

    Returns None if no fake mode is found. Raises AssertionError if the
    candidate sources disagree about which fake mode is in use.
    """
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    # (mode, description, index) candidates, collected in priority order.
    fake_modes = []

    if context := TracingContext.try_get():
        fake_mode = context.fake_mode
        if fake_mode is not None:
            fake_modes.append((fake_mode, "tracing context", 0))

    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            fake_modes.append((m, "active fake mode", i))

    flat_inputs = pytree.tree_leaves(inputs)
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

    if not fake_modes:
        return None

    # Every candidate must be the very same mode object as the first one.
    fake_mode, desc1, i1 = fake_modes[0]
    for m, desc2, i2 in fake_modes[1:]:
        assert fake_mode is m, (
            f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
            f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
            f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
        )
    return fake_mode


def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    # Walk from the innermost (most recently pushed) mode outwards.
    for m in reversed(_get_current_dispatch_mode_stack()):
        if isinstance(m, FakeTensorMode):
            return m

    return None