# mypy: allow-untyped-defs
import collections
import dataclasses
import enum
from typing import Any, Optional, Union

from torch._guards import ChainedSource, GuardSource, Source

from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr


# Lookup tables that reclassify a base source's GuardSource once the value it
# refers to is known to be an nn.Module. Every table maps LOCAL-flavoured keys to
# LOCAL-flavoured values and GLOBAL-flavoured keys to GLOBAL-flavoured values,
# which is why most of them are spelled as (local, global) pairs below.
#
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP
# module, so those cases are omitted intentionally.

# represents nn.Modules tracked with NNModuleVariable (specialized is implicit in
# the variable name)
_GUARD_SOURCE_SPECIALIZED_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
    # Already nn.Module-classified sources keep their classification; the
    # unspecialized entries exist just to ensure that guard_source() works.
    **{
        module_source: module_source
        for module_source in (
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )
    },
}

# represents nn.Modules tracked with UnspecializedNNModuleVariable
_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = {
    **{
        source: target
        for local_source, global_source in (
            (GuardSource.LOCAL, GuardSource.GLOBAL),
            (
                GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
                GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            ),
            # this happens for an UnspecializedNNModule submodule on a
            # NNModuleVariable
            (
                GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
                GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            ),
        )
        for source, target in (
            (local_source, GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE),
            (global_source, GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE),
        )
    },
    # Just to ensure that guard_source() works
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
}

# represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable; every
# known source collapses to the builtin-unspecialized source of the same locality
_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = {
    source: target
    for local_source, global_source in (
        (GuardSource.LOCAL, GuardSource.GLOBAL),
        (
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
        ),
        (
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
        ),
        # Just to ensure that guard_source() works
        (
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        ),
    )
    for source, target in (
        (local_source, GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE),
        (global_source, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE),
    )
}

# every known source collapses to the FSDP source of the same locality
_GUARD_SOURCE_FSDP_MODULE = {
    source: target
    for local_source, global_source in (
        (GuardSource.LOCAL, GuardSource.GLOBAL),
        (
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
        ),
        (GuardSource.LOCAL_FSDP_MODULE, GuardSource.GLOBAL_FSDP_MODULE),
        (
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
        ),
        (
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        ),
    )
    for source, target in (
        (local_source, GuardSource.LOCAL_FSDP_MODULE),
        (global_source, GuardSource.GLOBAL_FSDP_MODULE),
    )
}
71def is_constant_source(source): 72 if isinstance(source, ConstantSource): 73 return True 74 try: 75 if source.guard_source() == GuardSource.CONSTANT: 76 return True 77 except NotImplementedError: 78 pass 79 80 return False 81 82 83def reconstruct_getitem( 84 source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice 85): 86 source.base.reconstruct(codegen) 87 if isinstance(source.index, Source): 88 source.index.reconstruct(codegen) 89 else: 90 if index_is_slice: 91 assert isinstance(source, GetItemSource) 92 codegen.append_output(codegen.create_load_const(source.unpack_slice())) 93 else: 94 codegen.append_output(codegen.create_load_const(source.index)) 95 96 97@dataclasses.dataclass(frozen=True) 98class LocalSource(Source): 99 local_name: str 100 cell_or_freevar: bool = False 101 102 def reconstruct(self, codegen): 103 codegen.append_output(codegen.create_load(self.local_name)) 104 105 def guard_source(self): 106 return GuardSource.LOCAL 107 108 def name(self): 109 return f"L[{repr(self.local_name)}]" 110 111 112@dataclasses.dataclass(frozen=True) 113class SyntheticLocalSource(Source): 114 local_name: str 115 116 def reconstruct(self, codegen): 117 codegen.append_output(codegen.create_load(self.local_name)) 118 119 def guard_source(self): 120 return GuardSource.SYNTHETIC_LOCAL 121 122 def name(self): 123 return f"SYNTHETIC_LOCAL[{self.local_name!r}]" 124 125 126@dataclasses.dataclass(frozen=True) 127class RandomValueSource(Source): 128 random_call_index: int 129 130 def guard_source(self): 131 return GuardSource.RANDOM_VALUE 132 133 def reconstruct(self, codegen): 134 codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) 135 codegen.append_output(codegen.create_load_const(self.random_call_index)) 136 codegen.append_output(create_instruction("BINARY_SUBSCR")) 137 138 def name(self): 139 return f"random_value_{self.random_call_index}" 140 141 142@dataclasses.dataclass(frozen=True) 143class GlobalSource(Source): 144 
global_name: str 145 146 def reconstruct(self, codegen): 147 codegen.append_output(codegen.create_load_global(self.global_name, add=True)) 148 149 def guard_source(self): 150 return GuardSource.GLOBAL 151 152 def name(self): 153 return f"G[{repr(self.global_name)}]" 154 155 156@dataclasses.dataclass(frozen=True) 157class GlobalWeakRefSource(Source): 158 global_name: str 159 160 def reconstruct(self, codegen): 161 codegen.add_push_null( 162 lambda: codegen.append_output( 163 codegen.create_load_global(self.global_name, add=True) 164 ) 165 ) 166 codegen.extend_output(create_call_function(0, False)) 167 168 def guard_source(self): 169 return GuardSource.GLOBAL 170 171 def name(self): 172 return f"G[{repr(self.global_name)}]()" 173 174 175@dataclasses.dataclass(frozen=True) 176class WeakRefCallSource(ChainedSource): 177 def reconstruct(self, codegen): 178 codegen.add_push_null(lambda: self.base.reconstruct(codegen)) 179 codegen.extend_output(create_call_function(0, False)) 180 181 def guard_source(self): 182 return self.base.guard_source() 183 184 def name(self): 185 return f"{self.base.name()}()" 186 187 188@dataclasses.dataclass(frozen=True) 189class AttrSource(ChainedSource): 190 member: str 191 192 def __post_init__(self): 193 assert self.base, "Can't construct an AttrSource without a valid base source" 194 if "." in self.member: 195 member_parts = self.member.split(".") 196 object.__setattr__( 197 self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) 198 ) 199 object.__setattr__(self, "member", member_parts[-1]) 200 201 def reconstruct(self, codegen): 202 self.base.reconstruct(codegen) 203 codegen.extend_output(codegen.create_load_attrs(self.member)) 204 205 def guard_source(self): 206 return self.base.guard_source() 207 208 def name(self): 209 if not self.member.isidentifier(): 210 return f"getattr({self.base.name()}, {self.member!r})" 211 return f"{self.base.name()}.{self.member}" 212 213 214# Represents tensor.grad source. 
It could be represented by AttrSource as well. 215# But, we could access grad field on tensor directly in C++ without going 216# through the Python bytecodes. Therefore, we use a separate source for grad 217# field. 218@dataclasses.dataclass(frozen=True) 219class GradSource(ChainedSource): 220 member: str = "grad" 221 222 def reconstruct(self, codegen): 223 self.base.reconstruct(codegen) 224 codegen.extend_output(codegen.create_load_attrs(self.member)) 225 226 def guard_source(self): 227 return self.base.guard_source() 228 229 def name(self): 230 return f"{self.base.name()}.{self.member}" 231 232 233@dataclasses.dataclass(frozen=True) 234class ParamBufferSource(AttrSource): 235 def guard_source(self): 236 return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] 237 238 239# Special AttrSource to differentiate module._buffers or module._parameters 240@dataclasses.dataclass(frozen=True) 241class UnspecializedParamBufferSource(AttrSource): 242 pass 243 244 245# This source is intended to be used in places where a source is needed but it is expected 246# that the symbol will be simplified out later on. Symbols with ephemeral sources are 247# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral 248# source. Guarding on this source is an error. 249# 250# Example: During subclass view fake-ification, any close-over ViewFunc state should be 251# symbolicized / fake-ified to avoid invalid specialization during view replay. This source 252# is useful for symbols utilized in the middle of the view chain that are not expected to be 253# present within the final view shape metadata. 
254@dataclasses.dataclass(frozen=True) 255class EphemeralSource(Source): 256 desc: Optional[str] = None 257 258 def guard_source(self): 259 return GuardSource.EPHEMERAL 260 261 def name(self): 262 return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>" 263 264 def make_guard(self): 265 raise NotImplementedError 266 267 def is_ephemeral(self): 268 return True 269 270 271class TensorProperty(enum.Enum): 272 SIZE = 0 273 STRIDE = 1 274 STORAGE_OFFSET = 2 275 276 def method_name(self): 277 if self is TensorProperty.SIZE: 278 return "size" 279 elif self is TensorProperty.STRIDE: 280 return "stride" 281 elif self is TensorProperty.STORAGE_OFFSET: 282 return "storage_offset" 283 284 285@dataclasses.dataclass(frozen=True) 286class TensorPropertySource(ChainedSource): 287 prop: TensorProperty 288 idx: Optional[int] = None # None for STORAGE_OFFSET 289 290 def __post_init__(self): 291 assert self.base is not None 292 if self.prop is TensorProperty.STORAGE_OFFSET: 293 assert self.idx is None 294 else: 295 assert self.idx is not None 296 297 def reconstruct(self, codegen): 298 def gen_fn(): 299 self.base.reconstruct(codegen) 300 codegen.append_output(codegen.create_load_attr(self.prop.method_name())) 301 302 codegen.add_push_null(gen_fn) 303 if self.idx is not None: 304 codegen.append_output(codegen.create_load_const(self.idx)) 305 codegen.extend_output( 306 create_call_function(1 if self.idx is not None else 0, False) 307 ) 308 309 def guard_source(self): 310 return self.base.guard_source() 311 312 def name(self): 313 if self.prop is TensorProperty.SIZE: 314 return f"{self.base.name()}.size()[{self.idx}]" 315 elif self.prop is TensorProperty.STRIDE: 316 return f"{self.base.name()}.stride()[{self.idx}]" 317 elif self.prop is TensorProperty.STORAGE_OFFSET: 318 assert self.idx is None 319 return f"{self.base.name()}.storage_offset()" 320 else: 321 raise AssertionError(f"unhandled {self.prop}") 322 323 324@dataclasses.dataclass(frozen=True) 325class 
NegateSource(ChainedSource): 326 def __post_init__(self): 327 assert self.base is not None 328 329 def reconstruct(self, codegen): 330 raise NotImplementedError 331 332 def guard_source(self): 333 return self.base.guard_source() 334 335 def name(self): 336 # NB: use method call so that function stripping regexes work 337 return f"{self.base.name()}.__neg__()" 338 339 340@dataclasses.dataclass(frozen=True) 341class ConvertIntSource(ChainedSource): 342 def __post_init__(self): 343 assert self.base is not None 344 345 def reconstruct(self, codegen): 346 self.base.reconstruct(codegen) 347 348 def guard_source(self): 349 return self.base.guard_source() 350 351 def name(self): 352 return f"cast_symbool_to_symint_guardless({self.base.name()})" 353 354 355@dataclasses.dataclass(frozen=True) 356class FlattenScriptObjectSource(ChainedSource): 357 def __post_init__(self): 358 assert self.base is not None 359 360 def reconstruct(self, codegen): 361 self.base.reconstruct(codegen) 362 363 def guard_source(self): 364 return self.base.guard_source() 365 366 def name(self): 367 return f"{self.base.name()}.__obj_flatten__()" 368 369 370@dataclasses.dataclass(frozen=True) 371class ScriptObjectQualifiedNameSource(ChainedSource): 372 def __post_init__(self): 373 assert self.base is not None 374 375 def reconstruct(self, codegen): 376 self.base.reconstruct(codegen) 377 378 def guard_source(self): 379 return self.base.guard_source() 380 381 def name(self): 382 return f"{self.base.name()}._type().qualified_name()" 383 384 385class AttrProxySource(ChainedSource): 386 def reconstruct(self, codegen): 387 self.base.reconstruct(codegen) 388 389 def guard_source(self): 390 return self.base.guard_source() 391 392 def name(self): 393 return f"{self.base.name()}.get_base()" 394 395 396@dataclasses.dataclass(frozen=True) 397class DefaultsSource(ChainedSource): 398 idx_key: Union[int, str] 399 is_kw: bool = False 400 field: str = dataclasses.field(init=False, repr=False, compare=False) 401 _name: 
str = dataclasses.field(init=False, repr=False, compare=False) 402 403 def __post_init__(self): 404 assert ( 405 self.base 406 ), "Base must be a valid source in order to properly track and guard this Defaults to its origin." 407 if self.is_kw: 408 assert isinstance(self.idx_key, str) 409 object.__setattr__(self, "field", "__kwdefaults__") 410 object.__setattr__( 411 self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" 412 ) 413 else: 414 assert isinstance(self.idx_key, int) 415 object.__setattr__(self, "field", "__defaults__") 416 object.__setattr__( 417 self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" 418 ) 419 420 def reconstruct(self, codegen): 421 self.base.reconstruct(codegen) 422 codegen.extend_output(codegen.create_load_attrs(self.field)) 423 codegen.append_output(codegen.create_load_const(self.idx_key)) 424 codegen.append_output(create_instruction("BINARY_SUBSCR")) 425 426 def guard_source(self): 427 return self.base.guard_source() 428 429 def name(self): 430 return self._name 431 432 433@dataclasses.dataclass(frozen=True) 434class GetItemSource(ChainedSource): 435 index: Any 436 index_is_slice: bool = False 437 438 def __post_init__(self): 439 assert self.base is not None 440 if isinstance(self.index, slice): 441 # store the hashable version of the slice so the whole GetItemSource is hashable 442 super().__setattr__("index", self.index.__reduce__()) 443 super().__setattr__("index_is_slice", True) 444 445 def reconstruct(self, codegen): 446 reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice) 447 codegen.append_output(create_instruction("BINARY_SUBSCR")) 448 449 def guard_source(self): 450 return self.base.guard_source() 451 452 def unpack_slice(self): 453 assert self.index_is_slice 454 slice_class, slice_args = self.index 455 return slice_class(*slice_args) 456 457 def name(self): 458 # Index can be of following types 459 # 1) ConstDictKeySource 460 # 2) enum.Enum 461 # 3) index is a slice - example 1:4 
462 # 4) index is a constant - example string, integer 463 if isinstance(self.index, Source): 464 if not isinstance(self.index, ConstDictKeySource): 465 raise ValueError( 466 "GetItemSource index must be a constant, enum or ConstDictKeySource" 467 ) 468 return f"{self.base.name()}[{self.index.name()}]" 469 elif self.index_is_slice: 470 return f"{self.base.name()}[{self.unpack_slice()!r}]" 471 elif isinstance(self.index, enum.Enum): 472 return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]" 473 else: 474 return f"{self.base.name()}[{self.index!r}]" 475 476 477@dataclasses.dataclass(frozen=True) 478class ConstDictKeySource(GetItemSource): 479 def is_dict_key(self): 480 return True 481 482 def reconstruct(self, codegen): 483 codegen.add_push_null( 484 lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") 485 ) 486 self.base.reconstruct(codegen) 487 codegen.append_output(codegen.create_load_const(self.index)) 488 codegen.extend_output(create_call_function(2, False)) 489 490 def name(self): 491 # The list creation will be CSE'd by PyExprCSEPass 492 return f"list({self.base.name()}.keys())[{self.index!r}]" 493 494 495@dataclasses.dataclass(frozen=True) 496class TupleIteratorGetItemSource(GetItemSource): 497 def reconstruct(self, codegen): 498 codegen.add_push_null( 499 lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") 500 ) 501 self.base.reconstruct(codegen) 502 codegen.append_output(codegen.create_load_const(self.index)) 503 codegen.extend_output(create_call_function(2, False)) 504 505 def name(self): 506 return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" 507 508 509@dataclasses.dataclass(frozen=True) 510class TypeSource(ChainedSource): 511 def __post_init__(self): 512 assert self.base is not None 513 514 def reconstruct(self, codegen): 515 codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) 516 self.base.reconstruct(codegen) 517 
codegen.extend_output(create_call_function(1, False)) 518 519 def guard_source(self): 520 return self.base.guard_source() 521 522 def name(self): 523 return f"type({self.base.name()})" 524 525 526@dataclasses.dataclass(frozen=True) 527class ODictGetItemSource(ChainedSource): 528 index: Any 529 530 def __post_init__(self): 531 assert self.base is not None 532 533 def reconstruct(self, codegen): 534 codegen.add_push_null( 535 lambda: codegen.append_output( 536 codegen._create_load_const(collections.OrderedDict.__getitem__) 537 ) 538 ) 539 reconstruct_getitem(self, codegen, index_is_slice=False) 540 codegen.extend_output(create_call_function(2, False)) 541 542 def guard_source(self): 543 return self.base.guard_source() 544 545 def name(self): 546 if isinstance(self.index, type): 547 rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}' 548 return f"___odict_getitem({self.base.name()}, {rep})" 549 elif isinstance(self.index, Source): 550 return f"___odict_getitem({self.base.name()}, {self.index.name()})" 551 else: 552 return f"___odict_getitem({self.base.name()}, {self.index!r})" 553 554 555@dataclasses.dataclass(frozen=True) 556class OptimizerSource(ChainedSource): 557 def reconstruct(self, codegen): 558 self.base.reconstruct(codegen) 559 560 def guard_source(self): 561 return self.base.guard_source() 562 563 def name(self): 564 return self.base.name() 565 566 567@dataclasses.dataclass(frozen=True) 568class NNModuleSource(ChainedSource): 569 def reconstruct(self, codegen): 570 self.base.reconstruct(codegen) 571 572 def guard_source(self): 573 return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] 574 575 def name(self): 576 return self.base.name() 577 578 579@dataclasses.dataclass(frozen=True) 580class UnspecializedNNModuleSource(NNModuleSource): 581 def guard_source(self): 582 return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] 583 584 585@dataclasses.dataclass(frozen=True) 586class 
UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): 587 def guard_source(self): 588 return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] 589 590 591@dataclasses.dataclass(frozen=True) 592class FSDPNNModuleSource(NNModuleSource): 593 def guard_source(self): 594 return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] 595 596 597@dataclasses.dataclass(frozen=True) 598class GlobalStateSource(Source): 599 def name(self): 600 return "" 601 602 def guard_source(self): 603 return GuardSource.GLOBAL 604 605 606@dataclasses.dataclass(frozen=True) 607class TorchFunctionModeStackSource(Source): 608 ind: int 609 610 def name(self): 611 return "" 612 613 def _get_index(self): 614 from .variables.torch_function import TorchFunctionModeStackVariable 615 616 return TorchFunctionModeStackVariable.get_mode_index(self.ind) 617 618 def reconstruct(self, codegen): 619 codegen.add_push_null( 620 lambda: codegen.load_import_from( 621 utils.__name__, "get_torch_function_mode_stack_at" 622 ) 623 ) 624 codegen.extend_output([codegen.create_load_const(self._get_index())]) 625 codegen.extend_output(create_call_function(1, False)) 626 627 def guard_source(self): 628 return GuardSource.GLOBAL 629 630 631@dataclasses.dataclass(frozen=True) 632class ConstantSource(Source): 633 source_name: str 634 635 def reconstruct(self, codegen): 636 codegen.append_output(codegen.create_load_global(self.source_name, add=False)) 637 638 def guard_source(self): 639 return GuardSource.CONSTANT 640 641 def name(self): 642 return self.source_name 643 644 def make_guard(self, fn): 645 raise NotImplementedError 646 647 648@dataclasses.dataclass(frozen=True) 649class NumpyTensorSource(ChainedSource): 650 def name(self) -> str: 651 return f"___from_numpy({self.base.name()})" 652 653 def guard_source(self): 654 return self.base.guard_source() 655 656 def reconstruct(self, codegen): 657 codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) 658 
self.base.reconstruct(codegen) 659 codegen.extend_output(create_call_function(1, False)) 660 661 662@dataclasses.dataclass(frozen=True) 663class SubclassAttrListSource(ChainedSource): 664 def name(self) -> str: 665 return f"{self.base.name()}.__tensor_flatten__()[0]" 666 667 def guard_source(self): 668 return self.base.guard_source() 669 670 671# NB: We don't expect you to actually ever generate guards against this 672# source, it is ephemeral 673@dataclasses.dataclass(frozen=True) 674class FloatTensorSource(ChainedSource): 675 def name(self) -> str: 676 return f"___as_tensor({self.base.name()})" 677 678 def guard_source(self): 679 return self.base.guard_source() 680 681 682@dataclasses.dataclass(frozen=True) 683class CallMethodItemSource(ChainedSource): 684 def name(self) -> str: 685 return f"{self.base.name()}.item()" 686 687 def guard_source(self): 688 return self.base.guard_source() 689 690 691# This is a synthetic source that is associated with the singleton 692# shape env guard we always register for all frames. 
We get the actual 693# guard contents from the ambient ShapeEnv 694@dataclasses.dataclass(frozen=True) 695class ShapeEnvSource(Source): 696 def name(self): 697 return "" 698 699 def guard_source(self): 700 return GuardSource.SHAPE_ENV 701 702 703@dataclasses.dataclass(frozen=True) 704class BackwardStateSource(Source): 705 def name(self): 706 return "" 707 708 def guard_source(self): 709 return GuardSource.BACKWARD_STATE 710 711 712def is_from_local_source(source: Source, *, allow_cell_or_freevar=True): 713 if isinstance(source, ChainedSource): 714 return is_from_local_source( 715 source.base, allow_cell_or_freevar=allow_cell_or_freevar 716 ) 717 if not isinstance(source, LocalSource): 718 return False 719 if not allow_cell_or_freevar and source.cell_or_freevar: 720 return False 721 return True 722 723 724def is_from_unspecialized_param_buffer_source(source: Source): 725 if isinstance(source, UnspecializedParamBufferSource): 726 return True 727 if isinstance(source, ChainedSource): 728 return is_from_unspecialized_param_buffer_source(source.base) 729 return False 730 731 732def is_from_flatten_script_object_source(source: Source): 733 if isinstance(source, FlattenScriptObjectSource): 734 return True 735 elif isinstance(source, ChainedSource): 736 return is_from_flatten_script_object_source(source.base) 737 return False 738 739 740def is_from_optimizer_source(source: Source): 741 if isinstance(source, OptimizerSource): 742 return True 743 if isinstance(source, ChainedSource): 744 return is_from_optimizer_source(source.base) 745 return False 746 747 748# TODO: can probably write a generic "test this on everything in the chain" 749# helper 750def is_from_defaults(source: Source): 751 if isinstance(source, DefaultsSource): 752 return True 753 if isinstance(source, ChainedSource): 754 return is_from_defaults(source.base) 755 return False 756 757 758def is_cell_contents(source: Source): 759 return isinstance(source, AttrSource) and source.member == "cell_contents" 760