# mypy: ignore-errors

import collections
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
from ..source import AttrSource, Source
from ..utils import istype


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class MutableLocalSource(Enum):
    """
    Origin of the value described by a VariableTracker.mutable_local marker:

    - Existing: the value already existed, and Dynamo began tracking it
      during introspection.
    - Local: the value is new, created during Dynamo introspection.
    """

    Existing = 0
    Local = 1


class MutableLocalBase:
    """
    Base class for VariableTracker.mutable_local markers.

    Records, in ``self.scope``, the scope the tracked object belongs to, so
    that is_side_effect_safe() can decide whether mutating it from the
    current scope is allowed.
    """

    def __init__(self, typ: MutableLocalSource) -> None:
        # In HigherOrderOperator tracing, we need to distinguish
        # between MutableLocals inside the HigherOrderOperator and
        # ones outside it. For example, it is not safe to mutate
        # `a` in the following example because it was constructed
        # in a different scope.
        #
        # def f(x):
        #     a = 1
        #     def g(x):
        #         nonlocal a
        #         a = 2
        #         return x
        #     return wrap(g, x) + a
        #
        # We use self.scope to distinguish this.
        # scope == 0: The object was an existing variable
        # scope == 1: The object was created while Dynamo
        #             was introspecting a function
        #             (and no HigherOrderOps were involved)
        # scope >= 2: The object was created through
        #             Dynamo introspection of a HigherOrderOp.
        #             The exact number corresponds to the level
        #             of nested HigherOrderOps.
        if typ is MutableLocalSource.Existing:
            self.scope = 0
        elif typ is MutableLocalSource.Local:
            # Capture the scope that is active at construction time.
            self.scope = current_scope_id()
        else:
            unimplemented(f"Unsupported MutableLocalSource: {typ}")


class MutableLocal(MutableLocalBase):
    """
    Marker used to indicate this (list, iter, etc) was constructed in
    local scope and can be mutated safely in analysis without leaking
    state.

    Always created with MutableLocalSource.Local, so its scope is the
    scope that was current when the marker was constructed.
    """

    def __init__(self) -> None:
        super().__init__(MutableLocalSource.Local)

    # Identity-based hashing/equality: two markers compare equal only if
    # they are the very same object.
    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other) -> bool:
        return self is other


def _is_top_level_scope(scope_id) -> bool:
    """Return True if ``scope_id`` is the top-level introspection scope (1)."""
    return scope_id == 1


def is_side_effect_safe(m: MutableLocalBase) -> bool:
    """
    Return True if the object marked by ``m`` may be mutated from the
    currently active scope.

    Args:
        m: the mutable_local marker of the object to be mutated.
    """
    scope_id = current_scope_id()

    # In the top-level scope (if no HigherOrderOperators are involved),
    # we are allowed to modify variables created in this scope as well
    # as existing variables.
    if _is_top_level_scope(scope_id):
        return True
    # Otherwise, only allow local mutation of variables created in the current scope
    return m.scope == scope_id


class VariableTrackerMeta(type):
    """
    Metaclass for VariableTracker.

    Registers every class it creates and makes isinstance() see through
    LazyVariableTracker.
    """

    # Every class created with this metaclass, in creation order. This
    # includes VariableTracker itself, because __init__ below appends each
    # new class.
    all_subclasses = []

    def __instancecheck__(cls, instance) -> bool:
        """Make isinstance work with LazyVariableTracker.

        If ``instance`` is a LazyVariableTracker and ``cls`` is neither
        VariableTracker nor LazyVariableTracker, the lazy tracker is
        realized (``instance.realize()``) and the check runs against the
        realized value. As a consequence, an isinstance() check against a
        concrete subclass can trigger realization.
        """
        if type.__instancecheck__(
            variables.LazyVariableTracker, instance
        ) and cls not in (
            VariableTracker,
            variables.LazyVariableTracker,
        ):
            instance = instance.realize()
        return type.__instancecheck__(cls, instance)

    def __init__(cls, name, bases, attrs) -> None:
        super().__init__(name, bases, attrs)
        VariableTrackerMeta.all_subclasses.append(cls)


class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them (see clone()).
    """

    # Attribute names that visit() does not recurse into when walking an
    # instance's __dict__ in search of child VariableTrackers.
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutable_local",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes.

        Builds a new instance of the same class. Every attribute in
        ``self.__dict__`` is passed to the constructor as a keyword
        argument, with ``kwargs`` overriding individual entries. The
        constructor must therefore accept every instance attribute as a
        keyword argument.
        """
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def visit(
        cls,
        fn: Callable[["VariableTracker"], None],
        value: Any,
        cache: Optional[Dict[int, Any]] = None,
    ) -> None:
        """
        Walk value and call fn on all the VariableTracker instances

        Recursion rules:

        - VariableTracker: fn is called on it, then its ``__dict__`` entries
          are visited, except names listed in ``_nonvar_fields``.
        - list/tuple: each element is visited.
        - dict/OrderedDict: each value (not key) is visited.

        Container types are matched with istype() rather than isinstance().
        Any other value is ignored. Each object is visited at most once per
        walk, keyed by id().

        Args:
            fn: callback invoked on every VariableTracker found. It receives
                the unwrap()-ed tracker.
            value: root object to walk.
            cache: map of id -> object for objects already visited. A new
                dict is created when None.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return
        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = value

        if isinstance(value, VariableTracker):
            value = value.unwrap()
            fn(value)
            value = value.unwrap()  # calling fn() might have realized it
            nonvars = value._nonvar_fields
            for key, subvalue in value.__dict__.items():
                if key not in nonvars:
                    cls.visit(fn, subvalue, cache)
        elif istype(value, (list, tuple)):
            for subvalue in value:
                cls.visit(fn, subvalue, cache)
        elif istype(value, (dict, collections.OrderedDict)):
            for subvalue in value.values():
                cls.visit(fn, subvalue, cache)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def debug_repr(self) -> str:
        """Return repr() of the Python constant if there is one, else repr(self)."""
        # Intended to be overridden to provide more info
        try:
            return repr(self.as_python_constant())
        except NotImplementedError:
            return repr(self)

    def python_type(self):
        """
        Return the Python type represented by this VariableTracker.

        Subclasses that know their type should override this method. The
        default implementation is not abstract: it returns
        ``type(self.as_python_constant())``. It therefore only works for
        trackers that are Python constants.

        Returns:
            type: The Python type (such as int, str, list, etc.) of the
            variable tracked by the subclass. If the type cannot be
            determined or is not relevant, leaving it undefined or invoking
            super() is always sound.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: If as_python_constant() raises
                NotImplementedError, i.e. the tracker is not a constant.
        """
        try:
            return type(self.as_python_constant())
        except NotImplementedError:
            raise NotImplementedError(f"{self} has no type") from None

    def as_python_constant(self):
        """For constants: return the underlying Python value.

        The base class always raises NotImplementedError.
        """
        raise NotImplementedError(f"{self} is not a constant")

    def guard_as_python_constant(self):
        """Similar to as_python_constant(), but intended to add ID_MATCH
        guards to try to force things to become constants.

        This base implementation adds no guards itself. It returns
        as_python_constant() and reports NotImplementedError through
        unimplemented(). NOTE(review): guard-installing behaviour
        presumably lives in subclass overrides; confirm.
        """
        try:
            return self.as_python_constant()
        except NotImplementedError as e:
            unimplemented(str(e))

    def is_python_constant(self) -> bool:
        """Return True if as_python_constant() does not raise NotImplementedError."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """Build a guard from this tracker's source via ``source.make_guard(fn)``.

        Raises:
            NotImplementedError: if the tracker has no (truthy) source.
        """
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError

    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable

        The default implementation wraps const_getattr() in a
        ConstantVariable. It raises NotImplementedError when the value is
        not a literal. When this tracker has a source, the result gets
        ``AttrSource(self.source, name)``.
        """
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError
        source = None
        if self.source:
            source = AttrSource(self.source, name)
        return variables.ConstantVariable.create(value, source=source)

    def is_proxy(self) -> bool:
        """Return True if as_proxy() does not raise NotImplementedError."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        """Return the proxy backing this tracker; the base class has none."""
        raise NotImplementedError(str(self))

    def maybe_fx_node(self):
        """Return the torch.fx node behind as_proxy(), or None.

        Returns None if there is no proxy (NotImplementedError) or if the
        proxy is not a torch.fx.Proxy.
        """
        try:
            proxy = self.as_proxy()
            import torch.fx

            if isinstance(proxy, torch.fx.Proxy):
                return proxy.node
            return None
        except NotImplementedError:
            return None

    def reconstruct(self, codegen):
        """Use ``codegen`` to rebuild the Python object; unsupported here."""
        raise NotImplementedError

    def can_reconstruct(self, tx) -> bool:
        """If it is possible to reconstruct the Python object this
        VariableTracker represents.

        Runs reconstruct() against a scratch PyCodegen for ``tx``. Returns
        False only if reconstruct() raises NotImplementedError; any other
        exception propagates. Must be called with the root
        InstructionTranslator (asserted).
        """
        assert tx is tx.output.root_tx, "Only root tx can reconstruct"
        try:
            from ..codegen import PyCodegen

            cg = PyCodegen(tx)
            self.reconstruct(cg)
            return True
        except NotImplementedError:
            return False

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the elements of this tracker as a list of trackers; unsupported here."""
        raise NotImplementedError

    def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        # like unpack_var_sequence, but should only be used when it is
        # safe to eagerly (vs. lazily) unpack this variable.
        # e.g. map(f, x) is normally evaluated lazily but sometimes
        # we want to force eager unpacking, e.g. when converting to a list.
        # NOTE: this method is allowed to mutate the VariableTracker, so
        # it should only be called once.
        return self.unpack_var_sequence(tx)

    def has_unpack_var_sequence(self, tx) -> bool:
        """Return True if unpack_var_sequence(tx) does not raise NotImplementedError.

        Note that this performs a full unpack to find out.
        """
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    # NB: don't call force_unpack_var_sequence, especially if it mutates!
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return self.has_unpack_var_sequence(tx)

    def inspect_parameter_names(self) -> List[str]:
        unimplemented(f"inspect_parameter_names: {self}")

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        unimplemented(f"hasattr {self.__class__.__name__} {name}")

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Default handling of ``self.name(*args, **kwargs)``.

        Two cases are handled generically:

        - ``__len__`` when the value can be unpacked. Returns a
          ConstantVariable holding ``len(unpack_var_sequence(tx))``.
          Arguments are asserted to be empty. The sequence is unpacked once
          by has_unpack_var_sequence() and again for the length.
        - ``__getattr__`` with a single python-constant positional argument
          and no kwargs. Delegates to var_getattr().

        Anything else is reported via unimplemented().
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant())
        unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def set_name_hint(self, name):
        # No-op in the base class.
        pass

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self) -> bool:
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def next_variable(self, tx):
        unimplemented(f"next({self})")

    def is_strict_mode(self, tx):
        """Return ``tx.strict_checks_fn(self)`` if tx has a strict_checks_fn.

        Otherwise returns that (falsy) attribute value. The result is not
        coerced to bool.
        """
        return tx.strict_checks_fn and tx.strict_checks_fn(self)

    def __init__(
        self,
        *,
        source: Optional[Source] = None,
        mutable_local: Optional[MutableLocal] = None,
    ) -> None:
        """
        Args:
            source: where this value came from. It is used by var_getattr()
                to derive AttrSource children and by make_guard(). None when
                the value has no source.
            mutable_local: marker recording the scope the value belongs to
                (see MutableLocalBase). None when not set.
        """
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local


def typestr(*objs) -> str:
    """Return a short description of ``objs`` for messages.

    A VariableTracker is rendered with str(); any other object is rendered
    by its type name. Multiple objects are joined with single spaces, and
    zero objects yield "".
    """
    if len(objs) == 1:
        (obj,) = objs
        if isinstance(obj, VariableTracker):
            return str(obj)
        else:
            return type(obj).__name__
    else:
        return " ".join(map(typestr, objs))