# mypy: ignore-errors

import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import itertools
import logging
import math
import operator
import random
import re
import sys
import types
import warnings
import weakref
from typing import (
    Any,
    Callable,
    Dict,
    FrozenSet,
    List,
    MutableMapping,
    NamedTuple,
    Optional,
    Set,
    TYPE_CHECKING,
    Union,
)

import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
from torch._utils_internal import justknobs_check
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimDynamic,
    RelaxedUnspecConstraint,
    StatefulSymbolicContext,
    SubclassSymbolicContext,
    SymbolicContext,
)
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils.weak import TensorWeakRef

from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
    AttrProxySource,
    AttrSource,
    CallMethodItemSource,
    ConstantSource,
    ConstDictKeySource,
    ConvertIntSource,
    FloatTensorSource,
    GetItemSource,
    GradSource,
    is_cell_contents,
    is_constant_source,
    is_from_defaults,
    is_from_optimizer_source,
    LocalSource,
    NumpyTensorSource,
    OptimizerSource,
    RandomValueSource,
    Source,
    SubclassAttrListSource,
    TupleIteratorGetItemSource,
)
from ..trace_rules import (
    is_callable_allowed,
    is_numpy,
    is_numpy_dtype,
    is_numpy_type_info,
)
from ..utils import (
    _extract_tensor_dict,
    build_checkpoint_variable,
    clone_input,
    common_constant_types,
    get_fake_value,
    get_locals_to_steal,
    get_static_address_type,
    is_frozen_dataclass,
    is_function_or_wrapper,
    is_lru_cache_wrapped_function,
    is_namedtuple,
    is_parameter_freezing,
    is_typing,
    is_utils_checkpoint,
    is_wrapper_or_member_descriptor,
    istype,
    odict_values,
    proxy_args_kwargs,
    set_example_value,
    tensor_always_has_static_shape,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    unwrap_with_attr_name_if_wrapper,
    wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    AutocastModeVariable,
    EventVariable,
    NullContextVariable,
    PreserveVersionContextVariable,
    StreamContextVariable,
    StreamVariable,
)
from .dicts import (
    ConstDictVariable,
    CustomizedDictVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
    PythonSysModulesVariable,
    SetVariable,
)
from .distributed import (
    DeviceMeshVariable,
    PlacementClassVariable,
    PlacementVariable,
    ProcessGroupVariable,
    WorldMetaClassVariable,
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrapperUserFunctionVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
    BaseListVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    RestrictedListSubclassVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradEngineVariable,
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    DebuggingVariable,
    DelayGraphBreakVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    LoggingLoggerVariable,
    MethodWrapperVariable,
    NumpyDTypeVariable,
    NumpyTypeInfoVariable,
    NumpyVariable,
    PythonModuleVariable,
    RandomClassVariable,
    RandomVariable,
    RegexPatternVariable,
    SavedTensorBox,
    TorchVersionVariable,
    TypingVariable,
)
from .nn_module import (
    FSDPManagedNNModuleVariable,
    UnspecializedBuiltinNNModuleVariable,
    UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorSubclassVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
    build_torch_function_fn,
    TensorWithTFOverrideVariable,
    TorchFunctionModeVariable,
)
from .user_defined import (
    FrozenDataClassVariable,
    KeyedJaggedTensorVariable,
    MutableMappingVariable,
    SourcelessGraphModuleVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
    WeakRefVariable,
)


try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


DimList = List


def safe_has_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return hasattr(t, "grad")


class _missing:
    pass


@dataclasses.dataclass
class GraphArg:
    source: Source
    # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
    # thing to do. Probably should have example (which stores an int) and
    # fake_example
    _example: Union[TensorWeakRef, torch.SymInt]
    # When True, this indicates that this GraphArg is a Python quantity (e.g.,
    # a float or int) which we pass to the FX graph as a Tensor. This
    # controls how we codegen calls into the Dynamo graph: we will call
    # torch.as_tensor on the quantity before passing it in.
    #
    # Note that we typically do not pass dynamic integers as tensors, because
    # they will most frequently just be used for size computation. But this
    # is a policy decision that we can change our mind on; in particular, when
    # an int comes from a random number generator (e.g., random.randint), we
    # DO pass it as a tensor.
    #
    # It's also worth noting that our current tracing rules for
    # pass_arg_as_tensor are subtly broken: we just pun the variable as a
    # 0d scalar Tensor and pray that the semantics are the same. Which they
    # often are, but not necessarily. ezyang(May 2024) plans to fix this
    # soon.
    pass_arg_as_tensor: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True
    # Sometimes, the Tensor we pass to example is freshly allocated (smh).
    # Then we cannot keep only a weak reference to it. This lets you
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
            r = self._example()
            assert r is not None
            return r
        else:
            return self._example

    def __post_init__(self):
        if isinstance(self._example, torch.Tensor):
            self._example = TensorWeakRef(self._example)
            assert is_fake(self.fake_tensor)

    def reconstruct(self, codegen):
        self.source.reconstruct(codegen)

    def erase(self):
        self._example = None
        self.example_strong_ref = None

    def __eq__(self, other):
        return self.source.name() == other.source.name()


class BackwardStateGraphArg(GraphArg):
    def __init__(self) -> None:
        super().__init__(
            source=None,
            _example=BackwardState(),
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )

    def reconstruct(self, codegen):
        assert codegen.tx.output.backward_state_var
        codegen.add_push_null(
            lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState")
        )
        codegen.call_function(0, False)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)


@dataclasses.dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[List[int]]
    stride: Optional[List[int]]


# All class-based iterators in itertools
# NOTE: use id() because some objects are not hashable; they would raise an error during lookup
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
    id(member)
    for name, member in vars(itertools).items()
    if not name.startswith("_") and inspect.isclass(member)
)
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ) -> None:
        assert (
            source is not None
        ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
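        # The rest of the builder relies on an active TracingContext (e.g. for installing guards).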
        assert TracingContext.try_get() is not None, "Expected active TracingContext"
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        if value in self.tx.output.side_effects:
            side_effect_result = self.tx.output.side_effects[value]
            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
            if dup_guard:
                self.install_guards(dup_guard)
            return side_effect_result

        cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
        if cached_vt:
            return cached_vt

        vt = self._wrap(value)
        vt.source = self.source
        if (
            self._can_lift_attrs_to_inputs(vt)
            and value not in self.tx.output.side_effects
            and not is_wrapper_or_member_descriptor(value)
        ):
            vt = self.tx.output.side_effects.track_object_existing(value, vt)

        self.tx.output.variable_tracker_cache.add(value, self.source, vt)
        return vt

    def _can_lift_attrs_to_inputs(self, vt):
        return type(vt) in {
            TensorVariable,
            TensorWithTFOverrideVariable,
            UserDefinedObjectVariable,
            NumpyNdarrayVariable,
        }

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        return {
            # We zero-one specialize shapes, so specialize these constants
            # too
            0,
            1,
            # NB: There used to be more constants here, but honestly it was
            # pretty confusing. Note we specialize floats by default, and
            # DON'T specialize ints by default. This all only matters with
            # dynamic_shapes
        }

    def get_source(self):
        return self.source

    def install_guards(self, *guards):
        source = self.get_source()
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
        return {}

    def set_source_and_track_mutable(self, value, var):
        assert isinstance(var, VariableTracker)
        var.source = self.source
        return self.tx.output.side_effects.track_mutable(value, var)

    @classmethod
    @functools.lru_cache(None)
    def _type_dispatch(cls):
        # NB: Careful not to close over self to avoid ref cycle from lru_cache
        entries = [
            (
                (
                    torch.Tensor,
                    torch.nn.Parameter,
                    torch._subclasses.FakeTensor,
                    torch._subclasses.functional_tensor.FunctionalTensor,
                ),
                cls.wrap_tensor,
            ),
            (
                (tuple, list, odict_values, collections.deque, torch.Size),
                cls.wrap_listlike,
            ),
            (tuple_iterator, cls.wrap_tuple_iterator),
            ((slice, range), cls.wrap_slice_range),
            (tuple(common_constant_types), cls.wrap_literal),
            (re.Pattern, cls.wrap_regex_pattern),
            (weakref.ReferenceType, cls.wrap_weakref),
            (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
            (torch.jit.ScriptFunction, cls.wrap_jit_function),
        ]

        if config.trace_numpy and np:
            entries.append((np.ndarray, cls.wrap_numpy_ndarray))

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, tuple) else (ts,):
                assert t not in result
                result[t] = fn

        return result

    def wrap_regex_pattern(self, value: re.Pattern):
        # TODO(jansel): something like a REPR_MATCH might be more robust here
        self.install_guards(GuardBuilder.ID_MATCH)
        return RegexPatternVariable(value)

    def wrap_weakref(self, value: weakref.ReferenceType):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        return WeakRefVariable(value, source=self.source)

    def wrap_removable_handle(self, value):
        # This means that the removable handle was created in some other frame.
        # Our current infra requires the hook to be registered and removed in
        # the same frame. So graph break.
        # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks
        unimplemented("unregistered hook removable handle")

    def wrap_jit_function(self, value):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        return WrapperUserFunctionVariable(
            value, "_torchdynamo_inline", source=self.source
        )

    @classmethod
    @functools.lru_cache(None)
    def _id_dispatch(
        cls,
    ) -> Dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
        from ..comptime import comptime

        entries = [
            (
                inspect.signature,
                lambda self, value: LambdaVariable(
                    InspectSignatureVariable.create,
                    source=self.source,
                    **self.install_guards(GuardBuilder.CLOSURE_MATCH),
                ),
            ),
            (comptime, lambda self, value: ComptimeVariable()),
            (
                dataclasses.fields,
                lambda self, value: LambdaVariable(
                    _dataclasses_fields_lambda,
                    source=self.source,
                    **self.install_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
            (torch.__version__, lambda self, value: TorchVersionVariable()),
        ]

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, (tuple, list)) else (ts,):
                assert t not in result
                result[id(t)] = fn

        return result

    def _wrap(self, value):
        # import here to avoid circular dependencies
        from torch.utils._triton import has_triton

        if has_triton():
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        else:

            class JITFunction:
                pass

            class Autotuner:
                pass

        # Handle exact type() match
        type_dispatch = self._type_dispatch().get(type(value))
        if type_dispatch is not None:
            return type_dispatch(self, value)

        # Handle exact id() match
        id_dispatch = self._id_dispatch().get(id(value))
        if id_dispatch is not None:
            return id_dispatch(self, value)

        # Note - There are some nested values where types mismatch!
        # We want to get those out and wrap those.
        if is_function_or_wrapper(value):
            value = inspect.getattr_static(value, "_torchdynamo_inline", value)

        # Everything else (NB: order matters!)
        if is_traceable_wrapper_subclass(value) or istype(
            value, config.traceable_tensor_subclasses
        ):
            return self.wrap_tensor(value)
        elif is_namedtuple(value):
            return self.wrap_listlike(value)

        elif value is torch.utils._pytree.SUPPORTED_NODES:
            # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
            # under the assumption that the values themselves don't change.
            self.install_guards(GuardBuilder.DICT_VERSION)

            # The keys on SUPPORTED_NODES can be arbitrary, so guard on the
            # key order.
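            # guard_on_key_order records sources whose dict key iteration order must be guarded.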
            self.tx.output.guard_on_key_order.add(self.source.name())
            result = {
                ConstantVariable.create(k): UserDefinedObjectVariable(
                    v,
                    source=GetItemSource(
                        self.get_source(), ConstDictKeySource(self.get_source(), i)
                    ),
                )
                for i, (k, v) in enumerate(value.items())
            }
            return ConstDictVariable(result, type(value))
        elif value is sys.modules:
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return PythonSysModulesVariable(source=self.source)
        elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = CustomizedDictVariable.wrap(self, value)
            result.source = self.source
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
            self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

            # Optimisation for the common case: strings, ints, etc.
            all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
            if all_const:
                # TODO(anijain2305) - Do we have to guard on all the keys? Can
                # keys be guarded lazily, similar to values?
                self.install_guards(GuardBuilder.DICT_CONST_KEYS)
            else:
                # Guard on the key order
                # This is not ideal, i.e., there is no need to guard on the key
                # order. But we guard on the key order because of the complexity
                #
                # 1) For non-constant objects, we can't save the key in the
                # guard context because it can be memory heavy. We can add
                # weakrefs but this complicates the accesses.
                #
                # 2) For non-constant objects, we also have to guard on the keys
                # (like TENSOR_MATCH on tensor). We might also have guards on
                # the attributes of the keys (like tensor.grad). To make this
                # work in a tree structure is complicated.
                #
                # So, instead we guard on the key order. While guarding on key
                # order, we just save the indices and use it to access keys and
                # values. Indices are cheap to save.
                self.tx.output.guard_on_key_order.add(self.source.name())

            # We need all the keys to be hashable. We do this within the
            # _HashableTracker class in dicts.py
            def build_key_value(i, k, v):
                if all_const:
                    key = ConstantVariable.create(k)
                    source_key = k
                else:
                    source_key = ConstDictKeySource(self.get_source(), i)
                    key = LazyVariableTracker.create(k, source_key)

                source_value = GetItemSource(self.get_source(), source_key)
                value = LazyVariableTracker.create(v, source_value)

                return key, value

            result = dict(
                build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
            )

            if istype(value, collections.defaultdict):
                factory_source = AttrSource(self.source, "default_factory")
                result = DefaultDictVariable(
                    result,
                    type(value),
                    default_factory=VariableBuilder(self.tx, factory_source)(
                        value.default_factory
                    ),
                    source=self.source,
                )
            else:
                result = ConstDictVariable(result, type(value), source=self.source)

            return self.set_source_and_track_mutable(value, result)
        elif isinstance(value, torch.nn.Module):
            return self.wrap_module(value)
        elif ConstantVariable.is_literal(value):  # non-atomic literals
            return self.wrap_literal(value)
        elif isinstance(value, torch.overrides.TorchFunctionMode):
            var = TorchFunctionModeVariable(value, source=self.source)
            self.tx.output.side_effects.track_object_existing(value, var)
            return var
        elif istype(value, frozenset) and (
            ConstantVariable.is_literal(x) for x in value
        ):
            # For frozenset, we can guard by object ID instead of value
            # equality; this allows us to handle non-literal values
            self.install_guards(GuardBuilder.ID_MATCH)
            return ConstantVariable.create(value=value, source=self.source)
        elif isinstance(value, enum.Enum):
            self.install_guards(GuardBuilder.ID_MATCH)
            return EnumVariable(value=value, source=self.source)
        elif DebuggingVariable.is_reorderable_logging_function(value):
            # Put this above builtin_callable so that print() can be handled
            # along with other builtin debugging functions
            self.install_guards(GuardBuilder.BUILTIN_MATCH)
            return DebuggingVariable(value, source=self.source)
        elif isinstance(value, logging.Logger):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return LoggingLoggerVariable(value, source=self.source)
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif isinstance(value, functools.partial):
            func_src = AttrSource(self.get_source(), "func")
            func_obj = VariableBuilder(self.tx, func_src)(value.func)

            args = []
            args_source = AttrSource(self.get_source(), "args")
            for i, arg in enumerate(value.args):
                args.append(
                    VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
                )

            keywords = {}
            keywords_source = AttrSource(self.get_source(), "keywords")
            for k, v in value.keywords.items():
                if not ConstantVariable.is_literal(k):
                    unimplemented("functools.partial with non-literal keyword")
                keywords[k] = VariableBuilder(
                    self.tx, GetItemSource(keywords_source, k)
                )(v)

            install_guard(
                self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
                keywords_source.make_guard(GuardBuilder.DICT_KEYS),
                args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
            )
            return FunctoolsPartialVariable(func_obj, args, keywords)
        elif is_typing(value):
            # typing.List, typing.Mapping, etc.
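            # Typing constructs are effectively constants, so an ID_MATCH guard suffices.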
            self.install_guards(GuardBuilder.ID_MATCH)
            return TypingVariable(
                value,
                source=self.source,
            )
        elif np is not None and isinstance(value, np.generic):
            # numpy array scalars: convert to 0D arrays
            return self.wrap_numpy_ndarray(np.asarray(value))
        elif is_numpy(value):
            assert np
            self.install_guards(
                GuardBuilder.FUNCTION_MATCH
                if callable(value)
                else GuardBuilder.TYPE_MATCH
            )
            return NumpyVariable(value, source=self.source)
        elif is_numpy_dtype(value):
            self.install_guards(GuardBuilder.ID_MATCH)
            return NumpyDTypeVariable(value, source=self.source)
        elif is_numpy_type_info(value):
            if isinstance(value, np.iinfo):
                self.install_guards(GuardBuilder.TYPE_MATCH)
                dt_source = AttrSource(self.source, "dtype")
                install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
            else:
                self.install_guards(GuardBuilder.ID_MATCH)
            return NumpyTypeInfoVariable(value, source=self.source)
        # NB: These can't be put in type_dispatch; they have to run later
        elif CollectiveFunctionRewriteVariable.can_rewrite(value):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return CollectiveFunctionRewriteVariable.create(
                self.tx,
                value,
                source=self.source,
            )
        elif istype(value, torch.autograd.function.FunctionMeta):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return AutogradFunctionVariable(
                value,
                source=self.source,
            )
        elif isinstance(value, torch.autograd.function.FunctionCtx):
            actual_saved_tensors = None
            try:
                actual_saved_tensors = value.saved_tensors
            except RuntimeError:
                pass

            saved_tensors = []
            guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
            if isinstance(actual_saved_tensors, tuple):
                saved_tensors_source = AttrSource(self.source, "saved_tensors")
                guards.append(
                    saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
                )
                for i, v in enumerate(actual_saved_tensors):
                    saved_tensors.append(
                        VariableBuilder(
                            self.tx, GetItemSource(saved_tensors_source, i)
                        )(v)
                    )
            install_guard(*guards)

            return self.tx.output.side_effects.track_object_existing(
                value,
                AutogradFunctionContextVariable(
                    value,
                    source=self.source,
                    saved_tensors=SavedTensorBox(saved_tensors),
                ),
            )
        elif (
            isinstance(value, types.MethodType)
            and istype(
                getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
            )
            and getattr(value, "__name__", "") == "apply"
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__, source=AttrSource(self.source, member="__self__")
                ),
                "apply",
            )
        elif isinstance(value, torch._C._ImperativeEngine):
            self.install_guards(GuardBuilder.ID_MATCH)
            return AutogradEngineVariable(value, source=self.source)
        elif (
            value
            is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub
        ):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return LambdaVariable(
                lambda: UserFunctionVariable(
                    torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks,
                ).call_function(
                    self.tx,
                    (self.tx.output.side_effects.get_ca_final_callbacks_var(),),
                    {},
                )
            )
        elif callable(value) and trace_rules.lookup_callable(value) is not None:
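            # trace_rules maps this callable to the VariableTracker class that should wrap it.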
            if is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value).create_with_source(
                value, source=self.source
            )
        elif np and isinstance(value, np.number):
            return self.wrap_unspecialized_primitive(value)
        elif HFPretrainedConfigVariable.is_matching_object(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return HFPretrainedConfigVariable(value)
        elif isinstance(value, HigherOrderOperator):
            self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
            return TorchHigherOrderOperatorVariable.make(value, source=self.source)
        elif isinstance(value, torch.cuda.StreamContext):
            self.install_guards(GuardBuilder.ID_MATCH)
            stream_source = AttrSource(self.source, "stream")
            stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
            return StreamContextVariable.create(self.tx, stream_var)
        elif isinstance(value, _StreamBase):
            self.install_guards(GuardBuilder.ID_MATCH)
            stream_proxy = self.tx.output.create_proxy(
                "call_function",
                torch.cuda.Stream,
                (),
                {
                    "stream_id": value.stream_id,
                    "device_index": value.device_index,
                    "device_type": value.device_type,
                },
            )
            set_example_value(stream_proxy.node, value)
            return StreamVariable(
                stream_proxy,
                value,
                value.device,
                source=self.source,
            )
        elif isinstance(value, (torch._C._SDPAParams)):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return SDPAParamsVariable.create(self.tx, value, self.source)
        elif isinstance(value, _EventBase):
            self.install_guards(GuardBuilder.ID_MATCH)
            torch._dynamo.utils.store_user_object_weakref(value)
            event_proxy = self.tx.output.create_proxy(
                "call_function",
                torch._dynamo.utils.get_user_object_from_id,
                (id(value),),
                {},
            )
            set_example_value(event_proxy.node, value)
            return EventVariable(
                event_proxy,
                value,
                source=self.source,
            )
        elif (
            isinstance(value, torch._C._TensorMeta)
            and value in config.traceable_tensor_subclasses
        ):
            return TensorSubclassVariable(value, source=self.source)
        elif (
            istype(value, contextlib.nullcontext)
            and inspect.getattr_static(value, "enter_result", None) is None
        ):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return NullContextVariable(source=self.source)
        elif KeyedJaggedTensorVariable.is_matching_object(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = KeyedJaggedTensorVariable(value, source=self.source)
            # TODO: doing this manually is bad
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif isinstance(value, torch.optim.Optimizer):
            self.install_guards(GuardBuilder.ID_MATCH)
            self.source = OptimizerSource(self.source)
            return OptimizerVariable(value, source=self.source)
        elif WorldMetaClassVariable.is_group_member_type(value):
            return WorldMetaClassVariable(value, source=self.source)
        elif ProcessGroupVariable.is_process_group(value):
            self.install_guards(GuardBuilder.ID_MATCH)
            return ProcessGroupVariable(value, source=self.source)
        elif DeviceMeshVariable.is_device_mesh(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.EQUALS_MATCH)
            return DeviceMeshVariable(value, source=self.source)
        elif PlacementClassVariable.is_placement_type(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.ID_MATCH)
            return PlacementClassVariable(value, source=self.source)
        elif PlacementVariable.is_placement(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.EQUALS_MATCH)
            return PlacementVariable(
                value,
                source=self.source,
            )
        elif (
            id(value) in ITERTOOLS_TYPE_IDS
            and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
        ):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return ItertoolsVariable(value, source=self.source)
        elif isinstance(value, torch.SymBool):
            # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
            # user-provided SymBool with a SymInt in dynamo.

            # Concretely,
            # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source),
            # so that guards on the SymInt can be effectively applied to the original SymBool in the user program.
            # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv, because the original user program
            # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.

            new_source = ConvertIntSource(self.source)
            if value.node.has_hint():
                value_hint = value.node.require_hint()

                new_symint = (
                    self.tx.output.shape_env.create_unspecified_symint_and_symbol(
                        int(value_hint),
                        new_source,
                        dynamic_dim=DimDynamic.DYNAMIC,
                    )
                )
            else:
                # We need to create an unbacked symint to replace the unbacked symbool.
                new_symint = self.tx.output.shape_env.create_unbacked_symint()

            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(new_symint),
                source=new_source,
            )

            sym_node_proxy.node.meta["grapharg"] = GraphArg(
                new_source,
                new_symint,
                False,
                None,
                is_tensor=False,
                example_strong_ref=new_symint,
            )
            # We bind the new_symint to graph input.
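            # Registering the symbol and tracked fake below lets the ShapeEnv emit guards against this input's source.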
            set_example_value(sym_node_proxy.node, new_symint)
            self.tx.output.bound_symbols.add(new_symint.node.expr)
            self.tx.output.tracked_fakes.append(
                TrackedFake(new_symint, new_source, None)
            )
            return SymNodeVariable(
                sym_node_proxy,
                new_symint == 1,
            )
        elif isinstance(value, (JITFunction, Autotuner)):
            self.install_guards(GuardBuilder.ID_MATCH)
            return TritonKernelVariable(
                value,
                None,  # No kernel idx provided
                None,  # No grid provided
                source=self.source,
            )
        elif isinstance(value, torch.amp.autocast_mode.autocast):
            self.install_guards(GuardBuilder.ID_MATCH)
            return AutocastModeVariable(
                target_values=[
                    value.device,
                    value.fast_dtype,
                    value._enabled,
                    value._cache_enabled,
                ],
                source=self.source,
            )
        elif TorchCtxManagerClassVariable.is_matching_cls(value):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return TorchCtxManagerClassVariable(value, source=self.source)
        elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return WrapperUserFunctionVariable(
                value, "__original_fn", source=self.source
            )
        elif is_lru_cache_wrapped_function(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
        elif is_function_or_wrapper(value):
            value, attr_name = unwrap_with_attr_name_if_wrapper(value)
            # For these wrappers, Dynamo points to the wrapped function,
            # so source needs to be updated as well.
            if attr_name is not None:
                self.source = AttrSource(self.source, attr_name)
            return trace_rules.lookup(value).create_with_source(
                value, source=self.source
            )
        elif value is random.Random:
            self.install_guards(GuardBuilder.ID_MATCH)
            return RandomClassVariable(source=self.source)
        elif istype(value, random.Random) and RandomVariable.is_supported_random_obj(
            value
        ):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = RandomVariable(value, source=self.source)
            self.tx.output.side_effects.track_mutable(value, result)
            return result
        # Don't use istype, since some python modules are not subclasses of types.ModuleType directly.
        # E.g., type(torch.ops) -> <class 'torch._ops._Ops'>,
        # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
        elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            result = PythonModuleVariable(
                value,
                source=self.source,
            )
            self.tx.output.side_effects.track_object_existing(value, result)
            return result
        elif isinstance(value, types.MethodType) and isinstance(
            value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
        ):
            # don't let MethodTypes fall through to UserDefinedObject,
            # which doesn't support 'CALL_FUNCTION'

            # TODO(whc): Why do we limit this to methods on NNModules?
            # I don't have a good reason for this, but it preserves the existing behavior
            # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
            # I suspect we probably want to relax this check and dig deeper there.

            # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
            # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here
We wrap `self` here 1034 # and then `__func__` gets wrapped inside UserMethodVariable. 1035 self_obj = VariableBuilder( 1036 self.tx, source=AttrSource(self.source, "__self__") 1037 )(value.__self__) 1038 assert self_obj and isinstance( 1039 self_obj, VariableTracker 1040 ), "Failed to produce a valid self obj" 1041 self.install_guards(GuardBuilder.FUNCTION_MATCH) 1042 return UserMethodVariable( 1043 value.__func__, 1044 self_obj, 1045 source=self.source, 1046 ) 1047 elif isinstance(value, types.GetSetDescriptorType): 1048 # GetSet descriptors are C functions attached to an attribute lookup 1049 # using PyGetSetDef. Python, on attribute lookup, can decide to 1050 # create a new object on the fly, and therefore the `id` of the 1051 # descriptors is not guaranteed to be same for different attribute 1052 # accesses. Since these are unlikely to change during the program 1053 # execution, we can skip guarding on them. 1054 return GetSetDescriptorVariable(value) 1055 elif isinstance(value, types.MethodWrapperType): 1056 # Method-wrappers are written in C, and they are not guaranteed to 1057 # return the same object on attribute lookup. Therefore, we cannot 1058 # insert a FUNCTION_MATCH guard here. method-wrappers are very 1059 # unlikely to change, so its ok to skip the guard here. 1060 return MethodWrapperVariable(value) 1061 elif issubclass(type(value), type): 1062 if value in ( 1063 torch.utils.hooks.BackwardHook, 1064 torch.nn.Parameter, 1065 torch.nn.Buffer, 1066 ): 1067 # TODO(jansel): combine this case with the one above 1068 return trace_rules.lookup(value).create_with_source( 1069 value, source=self.source 1070 ) 1071 if value is torch.autograd._unsafe_preserve_version_counter: 1072 self.install_guards(GuardBuilder.FUNCTION_MATCH) 1073 return PreserveVersionContextVariable.constructor(self.tx) 1074 # This is a userdefined class, so install an ID_MATCH even if its a 1075 # global variable. 1076 self.install_guards(GuardBuilder.ID_MATCH) 1077 return UserDefinedClassVariable( 1078 value, 1079 source=self.source, 1080 ) 1081 elif RestrictedListSubclassVariable.is_matching_cls(type(value)): 1082 self.install_guards(GuardBuilder.SEQUENCE_LENGTH) 1083 return self.set_source_and_track_mutable( 1084 value, 1085 RestrictedListSubclassVariable( 1086 [ 1087 LazyVariableTracker.create( 1088 value=value[i], source=GetItemSource(self.source, i) 1089 ) 1090 for i in range(len(value)) 1091 ], 1092 user_cls=type(value), 1093 user_cls_source=AttrSource(self.source, "__class__"), 1094 ), 1095 ) 1096 elif TorchScriptObjectVariable.is_matching_cls(type(value)): 1097 from ..source import ( 1098 FlattenScriptObjectSource, 1099 ScriptObjectQualifiedNameSource, 1100 ) 1101 1102 if torch._library.fake_class_registry.tracing_with_real(value): 1103 proxy = self.tx.output.root_tracer.create_graph_input( 1104 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), 1105 type(value), 1106 source=self.source, 1107 ) 1108 1109 # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default 1110 # seting example to be real value because these example values will be used 1111 # as example_inputs for user compiler. 1112 proxy.node.meta["grapharg"] = GraphArg( 1113 self.source, value, False, None, False, value 1114 ) 1115 return TorchScriptObjectVariable.create( 1116 proxy, 1117 value, 1118 source=self.source, 1119 ) 1120 1121 # This exists to allow a smoother transition. 1122 # The implications are: 1123 # The script objects won't be tracked as proxies. 
            # Methods on these objects won't show up in the graph.
            # The original script object might be mutated.
            if not hasattr(value, "__obj_flatten__"):
                return self.wrap_user_defined(value)

            # Install the guards on the fully qualified name of the script object
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
                    value._type().qualified_name()  # type: ignore[attr-defined]
                )
            )
            # Install the guards on the content of the script object by setting the source
            # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
                    value.__obj_flatten__()
                )
            )

            fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
                self.tx.output.fake_mode, value
            )

            proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(value),
                source=self.source,
            )

            # Setting is_unspecialized=False so reconstruct does not insert an as_tensor call by default.
            # Setting example to the real value because these example values will be used
            # as example_inputs for the user compiler.
            proxy.node.meta["grapharg"] = GraphArg(
                self.source, value, False, None, False, fake_script_obj
            )
            return TorchScriptObjectVariable.create(
                proxy,
                fake_script_obj,
                source=self.source,
            )
        elif issubclass(type(value), MutableMapping):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return MutableMappingVariable(value, source=self.source)
        elif is_frozen_dataclass(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
            return self.tx.output.side_effects.track_object_existing(value, result)
        else:
            return self.wrap_user_defined(value)

    def wrap_user_defined(self, value: Any):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        result = UserDefinedObjectVariable(value, source=self.source)
        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
            # don't allow STORE_ATTR mutation with custom __setattr__
            return result
        return self.tx.output.side_effects.track_object_existing(value, result)

    def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
        if config.specialize_int and type(value) is torch.Size:
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value)

        # One can index a tensor with a list/tuple. Therefore, we need to
        # have a stricter match.
        self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

        for item in value:
            if item is value:
                unimplemented("list elements are pointing to the list itself")

        # Tuples are immutable objects, so we should mark their items static. This
        # avoids wrapping tuple items as symints. This helps for nn module
        # attributes like conv2d strides, dilations.
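        # Tuples of constant literals on unspecialized nn modules are specialized below; other sequences are wrapped lazily.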
        if (
            istype(value, tuple)
            and all(ConstantVariable.is_literal(item) for item in value)
            and self.source.guard_source().is_unspecialized_nn_module()
        ):
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return TupleVariable([ConstantVariable.create(item) for item in value])

        output = [
            LazyVariableTracker.create(
                item,
                source=GetItemSource(self.get_source(), i),
            )
            for i, item in enumerate(value)
        ]

        maybe_gm = self.tx.output.local_scope.get("self")
        if isinstance(
            self.source, LocalSource
        ) and self.source.local_name in get_locals_to_steal(maybe_gm):
            # The input tensor list to dynamo from compiled autograd may contain activations
            # which are freed as they are used in inductor. Dynamo's default behavior is to
            # lift all tensors to the graph inputs, but this will cause dynamo to hold an
            # extra reference to the activation tensors and increase peak memory usage.
            # To allow freeing ASAP, we keep the list as a graph argument to the dynamo output
            # graph, and unpack it locally.
            # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
            # `def forward(self, L_inputs_):`
            source = self.source
            assert isinstance(value, list)
            tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
            )
            tensor_list_proxy.node.meta["steal_arg"] = True

            list_variable = wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=self.tx,
                proxy=tensor_list_proxy,
                example_value=value,
                subclass_type=None,
                source=source,
            )

            guards = []
            for i, tensor_variable in enumerate(list_variable.items):
                source_i = GetItemSource(base=source, index=i, index_is_slice=False)
                # access unpacked tensor from this list instead of from a lifted arg
                self.tx.output.input_source_to_var[source_i] = tensor_variable
                tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
                    value[i]
                )

                guard = functools.partial(
                    GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
                )
                guards.append(source_i.make_guard(guard))

            install_guard(*guards, skip=1)

            grapharg = GraphArg(
                source,
                value,
                pass_arg_as_tensor=False,
                fake_tensor=None,
                is_tensor=False,
            )
            tensor_list_proxy.node.meta["grapharg"] = grapharg

        result = BaseListVariable.cls_for_instance(value)(
            output, mutable_local=MutableLocal()
        )
        if istype(value, list):
            return self.set_source_and_track_mutable(value, result)
        return result

    def wrap_tuple_iterator(self, value: tuple_iterator):
        self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
        output = [
            VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
                tuple_iterator_getitem(value, i)
            )
            for i in range(tuple_iterator_len(value))
        ]
        result = TupleIteratorVariable(
            output, mutable_local=MutableLocal(), source=self.source
        )

        return self.set_source_and_track_mutable(value, result)

    def wrap_slice_range(self, value: Union[slice, range]):
        items = [
            VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                getattr(value, k)
            )
            for k in ("start", "stop", "step")
        ]
        self.install_guards(GuardBuilder.TYPE_MATCH)
        if isinstance(value, slice):
            return SliceVariable(items, source=self.source)
        else:
            return RangeVariable(items, source=self.source)

    def mark_static_input(self, value: torch.Tensor, guard: bool):
        from ..decorators import mark_static_address

        static_inputs_log.debug(
            "Marking static input %s, id: %s", self.source.name(), id(value)
        )
        mark_static_address(value, guard=guard)

        # Check if we've seen this tensor before and update graph metadata if needed.
        # As long as this runs before AOT this is sound.
        if value in self.tx.output.side_effects:
            var = self.tx.output.side_effects[value]
            var.proxy.node.meta["tensor_dict"][
                "_dynamo_static_input_type"
            ] = value._dynamo_static_input_type

    def wrap_module(self, value: torch.nn.Module):
        from ..eval_frame import OptimizedModule

        if len(value.__dict__) == 0:
            unimplemented(f"uninitialized nn.Module: {typestr(value)}")
        if istype(value, OptimizedModule):
            # Check if the optimized module was disabled
            if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
                # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
                # we graph break here, Dynamo does not know how to create
                # continuation functions for such bytecodes. So, we delay the
                # graph break to CALL_FUNCTION.
                return DelayGraphBreakVariable(source=self.source)

            self.install_guards(GuardBuilder.TYPE_MATCH)
            self.source = AttrSource(self.source, "_orig_mod")
            return self.wrap_module(value._orig_mod)

        if (
            isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
            and not config.allow_rnn
        ):
            unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")

        if getattr(value, "_is_fsdp_managed_module", False):
            # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
            # in fully_sharded_data_parallel.py for more information

            # we can't do this assert inside the FSDP constructor,
            # since we don't know yet whether dynamo will be used
            assert getattr(
                value, "_fsdp_use_orig_params", False
            ), "Dynamo only supports FSDP with use_orig_params=True"

            # Note on FSDP guarding
            # Eager FSDP already assumes (requires, but without enforcement)
            # that users don't mutate their model parameters/structure after
            # FSDP wrapping, because FSDP wouldn't notice or update its
            # FlatParams.
            #
            # Therefore, torch.compile can skip guarding on params or submodule
            # structure of fsdp_managed modules, by using FSDPNNModuleSource as
            # the guard source. This behavior is gated on
            # config.skip_fsdp_guards.
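            # Hence, only a TYPE_MATCH guard is installed here; parameters are typically not guarded individually.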
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = FSDPManagedNNModuleVariable(value, source=self.get_source())
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
            # created dynamically, don't specialize on it

            # Note [Tracing a torch.compiled function]
            # When make_fx traces a compiled function, module attributes may be
            # wrapped in _AttrProxy objects, which we need to unwrap to the base module.
            if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy):
                value = value.get_base()
                self.source = AttrProxySource(self.source)

            self.install_guards(GuardBuilder.TYPE_MATCH)
            if torch._dynamo.config.inline_inbuilt_nn_modules:
                freezing = is_parameter_freezing()
                for p in value.parameters():
                    self.mark_static_input(p, guard=freezing)

                for b in value.buffers():
                    self.mark_static_input(b, guard=freezing)

                if freezing:
                    # we need to add the module to the tracing context
                    # in order to allow its params to get invalidated;
                    # this will get cleaned up once compile ends
                    self.tx.output.nn_modules[self.name] = value

            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
                value.__class__, "_dynamo_marked_static", False
            ):
                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
            else:
                result = UnspecializedNNModuleVariable(value, source=self.source)

            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif issubclass(
            value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
        ):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return UnspecializedNNModuleVariable(value, source=self.get_source())
        else:
            return self.tx.output.register_attr_or_module(
                value,
                self.name,
                source=self.get_source(),
                # Guards are added inside register_attr_or_module
            )

    def wrap_literal(self, value):
        if not config.specialize_int and type(value) is int:
            # unspecializing int by default, but still
            # specialize for the following conditions
            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
                # Assume integers from global variables want to be specialized
                not self.source.guard_source().is_local()
                # Assume that integers that came from NN modules want to be
                # specialized (as we don't expect users to be changing the
                # NN modules on the fly)
                or self.source.guard_source().is_specialized_nn_module()
                or self.source.guard_source().is_unspecialized_builtin_nn_module()
                or is_from_defaults(self.source)
                or is_cell_contents(self.source)
                # TODO: Delete this condition when rollout is done. NB: this
                # condition never evaluates True in open source
                or (
                    not justknobs_check(
                        "pytorch/dynamo:enable_unspecialize_zero_one_plain_int"
                    )
                    and value in self._common_constants()
                )
            ):
                self.install_guards(GuardBuilder.CONSTANT_MATCH)
                return ConstantVariable.create(value=value, source=self.source)
            else:
                return self.wrap_symint(value)
        elif not config.specialize_float and type(value) is float:
            return self.wrap_symfloat(value)
        else:
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            result = ConstantVariable.create(value=value, source=self.source)
            if isinstance(value, (list, set)):
                return self.set_source_and_track_mutable(value, result)
            return result

    def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
        if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
            raise InternalTorchDynamoError(
                "Cannot wrap a Tensor that has already been "
                "wrapped by this instance of Dynamo"
            )

    def wrap_tensor(self, value: torch.Tensor):
        source = self.get_source()

        # We cannot already be tracking the tensor, which implies
        # it would have already been wrapped
        assert value not in self.tx.output.side_effects

        is_static_input = get_static_address_type(value) is not None

        if (
            config.inline_inbuilt_nn_modules
            and not is_static_input
            and (
                isinstance(value, torch.nn.Parameter)
                # Mark tensor attributes of nn modules static. This is done to keep
                # inline_inbuilt_nn_modules behavior compatible with previous behavior.
                or (source and source.guard_source().is_unspecialized_nn_module())
            )
        ):
            self.mark_static_input(value, guard=is_parameter_freezing())
            is_static_input = True

        make_graph_attribute = is_static_input and (
            not config.inline_inbuilt_nn_modules or is_parameter_freezing()
        )

        if (
            source.guard_source().is_specialized_nn_module() or make_graph_attribute
        ) and not source.guard_source().is_fsdp_module():
            self.assert_not_wrapped_by_this_graph(value)
            return self.tx.output.register_attr_or_module(
                value, self.name, source=source
            )

        if is_constant_source(source):
            self.assert_not_wrapped_by_this_graph(value)
            return self.tx.output.register_attr_or_module(
                value,
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                source=source,
                # Guards are added inside register_attr_or_module
            )

        if type(value) in config.traceable_tensor_subclasses:
            # Ordinarily, we would fakeify a tensor so that it can get dynamic
            # shapes and be computed on without triggering actual operations.
            # However, how can we fakeify a tensor subclass? Neither ordinary
            # inheritance nor multiple inheritance will work.
            #
            # Instead, our plan is to *manually simulate* the tensor subclass
            # inheriting from a fake tensor with dynamo. This means our
            # data representation for a tensor subclass will be a fake tensor
            # + tensor subclass type + any extra data the subclass may have
            # been storing on the tensor. Because all Python accesses are
            # mediated through TensorWithTFOverrideVariable, we can ensure
            # that we dispatch differently, e.g., according to
            # __torch_function__
            #
            # To simplify things for now, the __dict__ tracking bits haven't
            # been implemented yet, but they can be added into this design at
            # a later point in time.
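            # Remember the subclass type; the wrapped variable uses it to dispatch via __torch_function__.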
            subclass_type = type(value)
        else:
            assert type(value) in (
                torch.Tensor,
                torch.nn.Parameter,
                torch._subclasses.fake_tensor.FakeTensor,
                torch._subclasses.functional_tensor.FunctionalTensor,
            ) or is_traceable_wrapper_subclass(value), type(value)
            subclass_type = None

        # NB: this just says we accessed a tensor from the same source again
        # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
        # This is distinct from two distinct sources mapping to the same
        # Tensor (per id())! No guard is necessary here. See below for the
        # other case.
        is_duplicate_tensor = source in self.tx.output.input_source_to_var
        if is_duplicate_tensor:
            return self.tx.output.input_source_to_var[source]

        if get_static_address_type(value) == "guarded":
            self.install_guards(GuardBuilder.ID_MATCH)

        # By this point, we should have deduplicated all tensors
        self.assert_not_wrapped_by_this_graph(value)

        # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
        # When we've discovered an untracked tensor, then we actually need
        # to get Dynamo to track the tensor (which is what this function does)
        # and put it as a graph input on the root tracer. Later on,
        # if the input is actually used in the body of the HigherOrderOperator,
        # then the relevant SubgraphTracer will lift it to being an input of
        # the subgraph.
        # See NOTE [HigherOrderOperator tracing design] for more details.

        tensor_proxy = self.tx.output.root_tracer.create_graph_input(
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
        )
        options = {}
        if type(value) in config.traceable_tensor_subclasses:
            options["torch_function_fn"] = build_torch_function_fn(
                self.tx, value, self.source
            )
            self.install_guards(GuardBuilder.TYPE_MATCH)

        if (
            isinstance(value, torch.Tensor)
            and value.is_nested
            and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
        ):
            unimplemented("torch.compile does not support strided NestedTensor")

        # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards
        if (
            isinstance(value, torch.Tensor)
            and is_sparse_any(value)
            and (not self.tx.export or not config.capture_sparse_compute)
        ):
            # A hot fix for sparse tensors + torch.compile. Support for
            # export + sparsity is being added, but we need to create
            # SPARSE_TENSOR_GUARDS for guards to work properly.
            unimplemented("torch.compile does not support sparse Tensors")

        if (
            safe_has_grad(value)
            and safe_grad(value) is not None
            and value.dtype != safe_grad(value).dtype
        ):
            unimplemented(
                "Inconsistent dtype between tensor and its gradient. "
                "This can happen in FSDP and crashes meta tensor creation. "
                "This is potentially a workaround. Fixing it correctly "
                "requires some design around FSDP + torch.compile."
1591 ) 1592 1593 tensor_variable = wrap_fx_proxy( 1594 tx=self.tx, 1595 proxy=tensor_proxy, 1596 example_value=value, 1597 subclass_type=subclass_type, 1598 source=source, 1599 **options, 1600 ) 1601 1602 guard_type = GuardBuilder.TENSOR_MATCH 1603 1604 if isinstance(source, GradSource) and is_from_optimizer_source(source): 1605 guard_type = GuardBuilder.NOT_NONE_MATCH 1606 1607 self.install_guards( 1608 functools.partial( 1609 guard_type, 1610 value=( 1611 value 1612 if isinstance(source, NumpyTensorSource) 1613 else TensorWeakRef(value) 1614 ), 1615 ) 1616 ) 1617 1618 # We install TYPE_MATCH guards for traceable wrapper subclass object, 1619 # and recursively install corresponding guard for each inner attribute. 1620 if is_traceable_wrapper_subclass(value): 1621 self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) 1622 self.install_guards(GuardBuilder.TYPE_MATCH) 1623 install_guard( 1624 SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) 1625 ) 1626 1627 attrs, _ = value.__tensor_flatten__() 1628 for attr in attrs: 1629 inner_value = getattr(value, attr) 1630 inner_source = AttrSource(self.source, attr) 1631 LazyVariableTracker.realize_all( 1632 VariableBuilder(self.tx, inner_source)(inner_value) 1633 ) 1634 1635 self.tx.output.input_source_to_var[source] = tensor_variable 1636 assert "tensor_dict" not in tensor_proxy.node.meta 1637 tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value) 1638 1639 # Note: this information is conveyed via subclass_type now 1640 fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] 1641 if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: 1642 raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") 1643 1644 grapharg = GraphArg(source, value, False, fake_tensor_value) 1645 tensor_proxy.node.meta["grapharg"] = grapharg 1646 self.tx.output.add_symbol_bindings(grapharg) 1647 return tensor_variable 1648 1649 def wrap_numpy_ndarray(self, value): 1650 assert np is not None 1651 assert isinstance(value, np.ndarray) 1652 1653 source = NumpyTensorSource(self.get_source()) 1654 1655 from torch._numpy import _util 1656 1657 readonly = not value.flags.writeable 1658 if readonly: 1659 try: 1660 value.flags.writeable = True 1661 except ValueError: 1662 # One can not easily make nditer elements writable, 1663 # but warning is not the end of the world 1664 assert isinstance(value.base, np.nditer) 1665 1666 try: 1667 tensor_value = _util._try_convert_to_tensor(value) 1668 if readonly: 1669 from torch._prims_common import clone_preserve_strides 1670 1671 tensor_value = clone_preserve_strides(tensor_value) 1672 except NotImplementedError as e: 1673 # failed to convert to tensor, graph break 1674 unimplemented(str(e)) 1675 1676 # We do this because we want the full behavior of guarding the numpy ndarray as if it were 1677 # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here 1678 # that there's not another great way to do this atm. 1679 # This creates the right graphargs, as well as registration for guards in tensor names and shape env. 
1680 LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value)) 1681 proxy = self.tx.output.root_tracer.create_graph_input( 1682 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source 1683 ) 1684 options = {"source": source} 1685 numpy_ndarray_variable = wrap_fx_proxy_cls( 1686 target_cls=NumpyNdarrayVariable, 1687 tx=self.tx, 1688 proxy=proxy, 1689 example_value=tensor_value, 1690 **options, 1691 ) 1692 1693 self.tx.output.input_source_to_var[source] = numpy_ndarray_variable 1694 example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] 1695 1696 # pass_arg_as_tensor should be true because we are wrapping an np.ndarray as an argument input, and it needs to be 1697 # converted to a tensor. 1698 grapharg = GraphArg( 1699 source, 1700 tensor_value, 1701 pass_arg_as_tensor=True, 1702 fake_tensor=example_value, 1703 is_tensor=True, 1704 example_strong_ref=tensor_value, 1705 ) 1706 proxy.node.meta["grapharg"] = grapharg 1707 1708 return numpy_ndarray_variable 1709 1710 def wrap_symint(self, value): 1711 assert type(value) is int 1712 1713 if self.name in self.tx.output.unspec_variable_map: 1714 return self.tx.output.unspec_variable_map[self.name] 1715 1716 shape_env = self.tx.output.shape_env 1717 if TracingContext.get().force_unspec_int_unbacked_size_like: 1718 wrapped_value = shape_env.create_unbacked_symint() 1719 _constrain_range_for_size(wrapped_value) 1720 self.tx.output.bound_symbols.add(wrapped_value.node.expr) 1721 self.tx.output.tracked_fakes.append( 1722 TrackedFake(wrapped_value, self.source, None) 1723 ) 1724 1725 # NB: We do not do float. For motivation, see 1726 # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit 1727 # but the general idea is that we generate kernels that can 1728 # take unspecialized floats and use them in sizevar computation 1729 elif not is_constant_source(self.get_source()): 1730 if torch._dynamo.config.specialize_int: 1731 # If specialize_int is False, also return 1732 # a constant (but this should have been handled 1733 # in the caller, TBH) 1734 self.install_guards(GuardBuilder.CONSTANT_MATCH) 1735 return ConstantVariable.create(value=value, source=self.source) 1736 1737 name = self.source.name() 1738 1739 def update_frame_state(value): 1740 if name not in self.tx.output.frame_state: 1741 # Note - this essentially means that if this name gets reused as a tensor, 1742 # it will start fully dynamic. That should always be a safe option, and not awfully inefficient. 1743 # Alternatively, if we want to improve perf here, we can add a third state of unset, but I am not 1744 # sure that is necessary for now. 1745 frame_state_entry = FrameStateSizeEntry( 1746 scalar=value, size=None, stride=None 1747 ) 1748 else: 1749 frame_state_entry = self.tx.output.frame_state[name] 1750 if frame_state_entry.scalar != value: 1751 log.debug( 1752 "automatic dynamic int %s val %s != %s", 1753 name, 1754 value, 1755 frame_state_entry.scalar, 1756 ) 1757 if self.source.guard_source().is_unspecialized_nn_module(): 1758 log.info( 1759 "%s", 1760 ( 1761 f"{name} is converted to a symbolic integer. It is an attribute of a " 1762 "user defined nn module class. If you wish to keep it static, you can " 1763 "mark the nn module class as `torch._dynamo.mark_static`."
1764 ), 1765 ) 1766 frame_state_entry.scalar = None 1767 self.tx.output.frame_state[name] = frame_state_entry 1768 1769 if (st := self.tx.distributed_state) is None: 1770 update_frame_state(value) 1771 frame_state_entry = self.tx.output.frame_state[name] 1772 elif st.all_states is None: 1773 # Preflight, always pretend as if it's static 1774 frame_state_entry = FrameStateSizeEntry( 1775 size=None, scalar=value, stride=None 1776 ) 1777 st.local_state.input_sizes[name] = value 1778 else: 1779 # Apply the updates 1780 for sub_state in st.all_states: 1781 update_frame_state(sub_state.input_sizes[name]) 1782 frame_state_entry = self.tx.output.frame_state[name] 1783 1784 # TODO: This should be dynamic, as we in general do not 1785 # know if bare integers are actually going to be sizevars 1786 # and it is inappropriate to eagerly duck size them with 1787 # real sizevars 1788 if ( 1789 config.automatic_dynamic_shapes and frame_state_entry.scalar is None 1790 ) or not config.assume_static_by_default: 1791 dynamic_dim = DimDynamic.DYNAMIC 1792 else: # assume_static_by_default 1793 # TODO: dynamic_dim = DimDynamic.STATIC should work but 1794 # for some reason it doesn't 1795 self.install_guards(GuardBuilder.CONSTANT_MATCH) 1796 return ConstantVariable.create(value=value) 1797 1798 wrapped_value = shape_env.create_unspecified_symint_and_symbol( 1799 value, 1800 source=self.source, 1801 dynamic_dim=dynamic_dim, 1802 ) 1803 self.tx.output.bound_symbols.add(wrapped_value.node.expr) 1804 1805 self.tx.output.tracked_fakes.append( 1806 TrackedFake(wrapped_value, self.source, None) 1807 ) 1808 else: 1809 assert is_constant_source(self.get_source()) 1810 # TODO: Do I actually need guard for constant source? 1811 self.install_guards(GuardBuilder.CONSTANT_MATCH) 1812 return ConstantVariable.create(value=value, source=self.source) 1813 1814 assert not isinstance(self.get_source(), RandomValueSource) 1815 install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) 1816 1817 options = {"source": self.get_source()} 1818 1819 proxy = self.tx.output.root_tracer.create_graph_input( 1820 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), 1821 type(wrapped_value), 1822 source=self.get_source(), 1823 ) 1824 1825 set_example_value(proxy.node, wrapped_value) 1826 unspec_var = SymNodeVariable(proxy, wrapped_value, **options) 1827 self.tx.output.unspec_variable_map[self.name] = unspec_var 1828 1829 if not is_constant_source(self.get_source()): 1830 if self.tx.export and not isinstance(self.get_source(), LocalSource): 1831 raise AssertionError( 1832 f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" 1833 ) 1834 1835 example_value = unspec_var.proxy.node.meta["example_value"] 1836 1837 proxy.node.meta["grapharg"] = GraphArg( 1838 self.get_source(), 1839 wrapped_value, 1840 pass_arg_as_tensor=False, 1841 fake_tensor=None, 1842 is_tensor=False, 1843 example_strong_ref=wrapped_value, 1844 ) 1845 1846 return unspec_var 1847 1848 def wrap_symfloat(self, value): 1849 # SymFloat wrapping is special. We first wrap it in the same way we 1850 # do an unspecialized primitive, and then we item() it into a 1851 # SymFloat. Removal of the item() call is left to a later FX pass, 1852 # mostly because that pass is more easily done after we have lowered 1853 # to ATen ops. (Dynamo doesn't do decomposition right now). 
1854 1855 if self.name in self.tx.output.unspec_variable_map: 1856 return self.tx.output.unspec_variable_map[self.name] 1857 1858 # NB: we specialize on nan input, because our guard modeling in 1859 # ShapeEnv cannot deal with nan 1860 if ( 1861 torch._dynamo.config.specialize_float 1862 or is_constant_source(self.get_source()) 1863 or math.isnan(value) 1864 ): 1865 self.install_guards(GuardBuilder.CONSTANT_MATCH) 1866 return ConstantVariable.create(value=value, source=self.source) 1867 1868 # NB: At the point we've gotten here, we don't assume static by 1869 # default. Since we have a guard mechanism, there isn't really any 1870 # downside to trying to be dynamic for float all the time. Unlike 1871 # ints, this won't make codegen perf worse. Modest cost to compile 1872 # time. 1873 1874 wrapped_value = torch.tensor(value, dtype=torch.float64) 1875 # TODO: Switch RandomValueSource over to use this, this is more 1876 # accurate 1877 assert not isinstance(self.get_source(), RandomValueSource) 1878 install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) 1879 1880 # The FloatTensorSource here is just for pedantic correctness: if you 1881 # guard against an UnspecializedPythonVariable, you need to guard 1882 # against the tensor-ified version of the local, otherwise it's not a 1883 # Tensor. However, we never let the UnspecializedPythonVariable escape 1884 # here, so there should never actually be any guards against this 1885 # source. 1886 options = {"source": FloatTensorSource(self.get_source()), "raw_value": value} 1887 1888 # TODO: Maybe the tensor-ification should be built into the source, 1889 # rather than by special pattern match 1890 proxy = self.tx.output.root_tracer.create_graph_input( 1891 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), 1892 type(wrapped_value), 1893 source=self.get_source(), 1894 ) 1895 1896 unspec_var = wrap_fx_proxy_cls( 1897 UnspecializedPythonVariable, 1898 tx=self.tx, 1899 proxy=proxy, 1900 example_value=wrapped_value, 1901 **options, 1902 ) 1903 assert isinstance(unspec_var, UnspecializedPythonVariable) 1904 self.tx.output.unspec_variable_map[self.name] = unspec_var 1905 1906 if self.tx.export and not isinstance(self.get_source(), LocalSource): 1907 raise AssertionError( 1908 f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" 1909 ) 1910 fake_tensor_value = None 1911 example_value = unspec_var.proxy.node.meta["example_value"] 1912 assert is_fake(example_value) 1913 1914 fake_tensor_value = example_value 1915 assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( 1916 f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" 1917 f" ({self.tx.fake_mode}) from InstructionTranslator" 1918 ) 1919 1920 # There's something a bit incoherent about pass_arg_as_tensor, 1921 # specifically regarding sources. 1922 # 1923 # Specifically, suppose we have an "x: float" local argument. We 1924 # eventually end up with an UnspecializedPythonVariable denoting 1925 # torch.as_tensor(x)... but its source is still L['x'] (which if you 1926 # accessed it directly is a float!) So you have to be careful when 1927 # setting up your guards, because it's still going to be a float at 1928 # this point; the conversion happens only precisely at the point we're 1929 # actually calling the FX graph. This happens to be what we want for 1930 # shape guard generation, but it's kind of unintuitive.
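        # A hedged sketch of the situation described above (editor's illustration;
        # `f` is a made-up function):
        #
        #     def f(x: float):
        #         return x * 2
        #
        #     torch.compile(f)(3.5)
        #
        # At guard time L['x'] is still a plain Python float; only when the FX graph
        # is actually invoked is the value tensor-ified (roughly torch.as_tensor(x),
        # as float64 per the wrapping above).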
1931 proxy.node.meta["grapharg"] = GraphArg( 1932 self.get_source(), 1933 wrapped_value, 1934 pass_arg_as_tensor=True, 1935 fake_tensor=fake_tensor_value, 1936 is_tensor=False, 1937 example_strong_ref=wrapped_value, 1938 ) 1939 1940 # Directly do item to bypass capture_scalar_outputs 1941 r = wrap_fx_proxy( 1942 self.tx, 1943 self.tx.output.create_proxy( 1944 "call_method", 1945 "item", 1946 *proxy_args_kwargs([unspec_var], {}), 1947 ), 1948 ) 1949 self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None)) 1950 1951 return r 1952 1953 def wrap_unspecialized_primitive(self, value): 1954 if self.name in self.tx.output.unspec_variable_map: 1955 return self.tx.output.unspec_variable_map[self.name] 1956 1957 wrapped_value = torch.tensor(value) 1958 if not isinstance(self.get_source(), RandomValueSource): 1959 install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) 1960 1961 options = {"source": self.get_source()} 1962 options.update({"raw_value": value}) 1963 1964 proxy = self.tx.output.root_tracer.create_graph_input( 1965 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), 1966 type(wrapped_value), 1967 source=self.get_source(), 1968 ) 1969 1970 unspec_var = wrap_fx_proxy_cls( 1971 UnspecializedPythonVariable, 1972 tx=self.tx, 1973 proxy=proxy, 1974 example_value=wrapped_value, 1975 **options, 1976 ) 1977 self.tx.output.unspec_variable_map[self.name] = unspec_var 1978 if not is_constant_source(self.get_source()): 1979 if self.tx.export and not isinstance(self.get_source(), LocalSource): 1980 raise AssertionError( 1981 f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" 1982 ) 1983 fake_tensor_value = None 1984 if isinstance(unspec_var, ConstantVariable): 1985 # TODO: when can this happen? 
1986 example_value = unspec_var.value 1987 else: 1988 example_value = unspec_var.proxy.node.meta["example_value"] 1989 assert is_fake(example_value) 1990 1991 fake_tensor_value = example_value 1992 assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( 1993 f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" 1994 f" ({self.tx.fake_mode}) from InstructionTranslator" 1995 ) 1996 1997 proxy.node.meta["grapharg"] = GraphArg( 1998 self.get_source(), 1999 wrapped_value, 2000 pass_arg_as_tensor=True, 2001 fake_tensor=fake_tensor_value, 2002 is_tensor=False, 2003 example_strong_ref=wrapped_value, 2004 ) 2005 return unspec_var 2006 2007 2008 def _dataclasses_fields_lambda(obj): 2009 if isinstance(obj, UserDefinedObjectVariable): 2010 value = obj.value 2011 elif isinstance(obj, CustomizedDictVariable): 2012 value = obj.user_cls 2013 else: 2014 unimplemented(f"Dataclass fields handling fails for type {obj}") 2015 items = [] 2016 for field in dataclasses.fields(value): 2017 source = None 2018 if obj.source: 2019 source = GetItemSource( 2020 AttrSource(obj.source, "__dataclass_fields__"), field.name 2021 ) 2022 items.append(UserDefinedObjectVariable(field, source=source)) 2023 return TupleVariable(items) 2024 2025 2026 def wrap_fx_proxy( 2027 tx, proxy, example_value=None, subclass_type=None, **options 2028 ) -> VariableTracker: 2029 kwargs = { 2030 "tx": tx, 2031 "proxy": proxy, 2032 "example_value": example_value, 2033 "subclass_type": subclass_type, 2034 **options, 2035 } 2036 if subclass_type is None: 2037 return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) 2038 else: 2039 result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) 2040 result.install_global(tx) 2041 return result 2042 2043 2044 # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable. 2045 # Should be compositional instead 2046 # 2047 # This is a horribly complicated function that does too many things. To 2048 # explain what it does, let's first talk about the classic usage of wrap_fx_proxy 2049 # for a TensorVariable. There are two primary modes of use: 2050 # 2051 # 1. Wrapping a pre-existing Tensor. In this case, example_value is set 2052 # to the pre-existing Tensor. (Note that this example_value will NOT 2053 # be the final example_value we put into node.meta['example_value'], 2054 # instead it is converted into a fake tensor using 2055 # wrap_to_fake_tensor_and_record and registered as a graph input.) 2056 # 2057 # 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In 2058 # this case, example_value is None (and we are going to figure it out 2059 # ourselves using FakeTensors, via get_fake_value, which will run 2060 # the operation represented by the (singular!) FX node referenced by 2061 # the passed in proxy.) 2062 # 2063 # The expectation is you end up with a Tensor output, and everything is 2064 # straightforwardly traced into the graph. 2065 # 2066 # In all cases, the returned `TensorVariable` subclass will have an `example_value` 2067 # and that `example_value` must be a `FakeTensor` produced by the currently running 2068 # instance of Dynamo. 2069 # 2070 # Upon closer inspection, you may notice that there is a slurry of non-Tensor 2071 # output cases. What gives? Well, we sometimes trace operations into the 2072 # graph that don't involve tensors.
2073# 2074# * Some operators return tuples; we need to recursively handle their 2075# contents 2076# 2077# * Some operators have side effects that will affect subsequent AOTAutograd 2078# tracing but don't otherwise return anything. 2079# 2080# * Some operators return symbolic ints/floats/bools which can go in the 2081# graph and be traced (but only if they're actually symbolic! If they're 2082# static you don't want to put them in the graph, which means you 2083# shouldn't call this function.) 2084# 2085# The common theme is that you only use this function WHEN YOU ARE TRACING 2086# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call 2087# this function without a proxy. 2088def wrap_fx_proxy_cls( 2089 target_cls, tx, proxy, example_value=None, subclass_type=None, **options 2090): 2091 from ..symbolic_convert import InstructionTranslatorBase 2092 2093 assert isinstance(tx, InstructionTranslatorBase) 2094 if "guards" in options and options["guards"] is not None: 2095 tx.output.guards.update(options["guards"]) 2096 2097 assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" 2098 2099 initial_example_value = example_value 2100 2101 def _clone_input(value): 2102 if isinstance(value, torch.Tensor): 2103 # tensor subclasses will not be converted to FakeTensors and need to be cloned 2104 if not ( 2105 isinstance(value, FakeTensor) 2106 or ( 2107 # Is functional tensor fakeified by this instance of Dynamo 2108 torch._is_functional_tensor(value) 2109 and maybe_get_fake_mode(value) is tx.fake_mode 2110 ) 2111 or value.is_nested 2112 ): 2113 # NB: ensure strides are preserved 2114 value = clone_input(value) 2115 2116 return value 2117 2118 # See NOTE: [Deferring tensor pack/unpack hooks until runtime] 2119 with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): 2120 # with preserve_rng_state(): 2121 if example_value is None: 2122 # only allow_non_graph_fake in this instance because we handle the non-fake 2123 # cases properly below. 2124 example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) 2125 2126 # Handle recursive calls here 2127 elif maybe_get_fake_mode(example_value) is tx.fake_mode: 2128 pass 2129 2130 elif isinstance(example_value, torch.Tensor): 2131 if tx.export: 2132 # The legacy behavior for real value cache with subclasses was 2133 # to perform a clone WITHOUT preserving the subclass. It's 2134 # not entirely clear this is what you actually want though. 2135 with torch._C.DisableTorchFunctionSubclass(): 2136 proxy.tracer.real_value_cache[proxy.node] = _clone_input( 2137 example_value 2138 ) 2139 # NB: If we're ignoring subclass, then the expectation is you will 2140 # take the returned TensorVariable and wrap it into a more 2141 # accurate TensorVariable that is able to track subclass-ness; 2142 # otherwise this is wrong! 2143 kwargs = { 2144 "is_tensor": target_cls 2145 in (TensorVariable, TensorWithTFOverrideVariable), 2146 } 2147 assert "source" in options and options["source"] is not None 2148 kwargs["source"] = options["source"] 2149 example_value = wrap_to_fake_tensor_and_record( 2150 example_value, tx=tx, **kwargs 2151 ) 2152 if ( 2153 isinstance(example_value, torch.Tensor) 2154 and example_value.device.type != "meta" 2155 and (maybe_get_fake_mode(example_value) is not tx.fake_mode) 2156 ): 2157 raise InternalTorchDynamoError( 2158 "`example_value` needs to be a `FakeTensor`" 2159 f"wrapped by this instance of Dynamo. 
Found: {example_value}" 2160 ) 2161 2162 if isinstance(example_value, torch.Tensor): 2163 is_parameter = isinstance(example_value, torch.nn.Parameter) 2164 is_buffer = isinstance(example_value, torch.nn.Buffer) 2165 2166 # NB: In most (all?) cases, this does not actually do a clone. 2167 # (WARNING: this means that if we mutate metadata on the fake 2168 # tensor, the stored example value will update too!) 2169 example_value = _clone_input(example_value) 2170 set_example_value(proxy.node, example_value) 2171 specialized_props = target_cls.specialize(example_value) 2172 # TODO: not sure about this fake mode test 2173 if ( 2174 isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) 2175 and example_value.fake_mode is tx.fake_mode 2176 ): 2177 tensor_type = subclass_type if subclass_type else torch.Tensor 2178 specialized_props["class_type"] = ( 2179 torch.nn.Parameter 2180 if is_parameter 2181 else torch.nn.Buffer 2182 if is_buffer 2183 else tensor_type 2184 ) 2185 2186 options.update(specialized_props) 2187 return target_cls(proxy, **options) 2188 elif ( 2189 hasattr(proxy.node.target, "__name__") 2190 and proxy.node.target.__name__ == "set_state" 2191 and isinstance(proxy.node.target.__self__, torch._C.Generator) 2192 or proxy.node.target == torch.random.set_rng_state 2193 ): 2194 return TorchInGraphFunctionVariable(proxy.node.target) 2195 elif ( 2196 proxy.node.target == torch._C._DisableFuncTorch 2197 or proxy.node.target == torch.cuda._is_in_bad_fork 2198 ): 2199 return UserDefinedObjectVariable(example_value) 2200 elif istype(example_value, torch.Size) and all( 2201 isinstance(x, int) for x in example_value 2202 ): 2203 sizes = [ConstantVariable.create(x) for x in example_value] 2204 return SizeVariable(sizes, **options) 2205 elif isinstance(example_value, (tuple, list)): 2206 set_example_value(proxy.node, example_value) 2207 unpacked = [] 2208 for i, val in enumerate(example_value): 2209 if val is None: 2210 # nn.MultiheadAttention() can return None, see issue #175 2211 unpacked.append( 2212 ConstantVariable.create(None, **options), 2213 ) 2214 else: 2215 proxy_i = proxy.tracer.create_proxy( 2216 kind="call_function", 2217 target=operator.getitem, 2218 args=(proxy, i), 2219 kwargs={}, 2220 ) 2221 2222 if "source" in options: 2223 source = options["source"] 2224 options_i = options.copy() 2225 options_i["source"] = GetItemSource( 2226 base=source, index=i, index_is_slice=False 2227 ) 2228 else: 2229 # use the same options object as parent 2230 options_i = options 2231 2232 # WARNING: this assumes the same target_cls as this tuple/list call 2233 unpacked.append( 2234 wrap_fx_proxy_cls( 2235 target_cls=target_cls, 2236 tx=tx, 2237 proxy=proxy_i, 2238 example_value=val, 2239 **options_i, 2240 ) 2241 ) 2242 if isinstance(example_value, torch.Size): 2243 # NB: Keep the old proxy around. 
See SizeVariable for an 2244 # explanation why 2245 return SizeVariable(unpacked, proxy, **options) 2246 elif istype(example_value, tuple): 2247 return TupleVariable(unpacked, **options) 2248 elif istype(example_value, (list, immutable_list)): 2249 return ListVariable(unpacked, mutable_local=MutableLocal(), **options) 2250 else: 2251 assert example_value.__class__.__module__ == "torch.return_types" or hasattr( 2252 example_value, "_fields" 2253 ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" 2254 return NamedTupleVariable(unpacked, example_value.__class__, **options) 2255 elif example_value is None or proxy.node.target is torch.manual_seed: 2256 return ConstantVariable.create(None, **options) 2257 elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): 2258 set_example_value(proxy.node, example_value) 2259 return SymNodeVariable(proxy, example_value, **options) 2260 elif ( 2261 inspect.isclass(proxy.node.target) 2262 and issubclass(proxy.node.target, _StreamBase) 2263 ) or proxy.node.target in [ 2264 device_interface.current_stream 2265 for _, device_interface in get_registered_device_interfaces() 2266 ]: 2267 set_example_value(proxy.node, example_value) 2268 return StreamVariable(proxy, example_value, example_value.device, **options) 2269 elif ( 2270 inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase) 2271 ) or proxy.node.target in [ 2272 device_interface.Event 2273 for _, device_interface in get_registered_device_interfaces() 2274 ]: 2275 set_example_value(proxy.node, example_value) 2276 return EventVariable(proxy, example_value, **options) 2277 elif proxy.node.target == "query" and proxy.node.op == "call_method": 2278 set_example_value(proxy.node, example_value) 2279 return ConstantVariable(example_value, **options) 2280 elif ( 2281 example_value is not None 2282 and isinstance(example_value, _EventBase) 2283 and proxy.node.target == "record_event" 2284 and proxy.node.op == "call_method" 2285 ): 2286 set_example_value(proxy.node, example_value) 2287 return EventVariable(proxy, example_value, **options) 2288 elif isinstance(example_value, int) and ( 2289 proxy.node.target 2290 in [ 2291 torch.sym_int, 2292 getattr, 2293 operator.getitem, 2294 torch._utils._element_size, 2295 torch.seed, 2296 operator.mod, 2297 torch._functorch.vmap._validate_and_get_batch_size, 2298 # some mac builds are missing torch.distributed.get_rank() 2299 getattr(torch.distributed, "get_rank", _missing), 2300 getattr(torch.distributed, "get_world_size", _missing), 2301 # This always wants to be in the graph, even if the constraint 2302 # results in a constant int 2303 torch._constrain_as_size, 2304 ] 2305 or ( 2306 # TODO: this is a little sus, because we didn't check what the self is 2307 proxy.node.op == "call_method" 2308 and proxy.node.target in ["bit_length"] 2309 ) 2310 ): 2311 set_example_value(proxy.node, example_value) 2312 return ConstantVariable.create(example_value, **options) 2313 elif isinstance(example_value, torch.backends.cuda.SDPAParams): 2314 from .sdpa import SDPAParamsVariable 2315 2316 set_example_value(proxy.node, example_value) 2317 return SDPAParamsVariable(proxy, **options) 2318 elif isinstance(example_value, bool) and proxy.node.target in [ 2319 torch._C._are_functorch_transforms_active, 2320 torch.backends.cuda.is_flash_attention_available, 2321 torch.backends.cuda.can_use_flash_attention, 2322 torch.backends.cuda.can_use_efficient_attention, 2323 ]: 2324 
set_example_value(proxy.node, example_value) 2325 return ConstantVariable.create(example_value, **options) 2326 elif ( 2327 isinstance(example_value, (int, float, bool)) 2328 and proxy.node.target is call_torchbind 2329 ): 2330 set_example_value(proxy.node, example_value) 2331 return ConstantVariable.create(example_value, **options) 2332 else: 2333 unimplemented( 2334 "torch.* op returned non-Tensor " 2335 + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}", 2336 case_name="unsupported_operator", 2337 ) 2338 2339 2340# Tracks the sources of all fake tensors we wrap in Dynamo. 2341# Used by shape guard computation. 2342@dataclasses.dataclass 2343class TrackedFake: 2344 fake: Union[FakeTensor, SymInt] 2345 source: Source 2346 # Is None when fake is SymInt 2347 symbolic_context: Optional[SymbolicContext] 2348 2349 def __hash__(self) -> int: 2350 return hash((self.fake, self.source.name())) 2351 2352 def __eq__(self, other: object) -> bool: 2353 if isinstance(other, TrackedFake): 2354 return self.fake is other.fake and self.source.name() == other.source.name() 2355 return False 2356 2357 2358# Performs automatic dynamic dim determination. 2359# Returns a SymbolicContext 2360def _automatic_dynamic( 2361 e, tx, source, static_shapes, outer_only=False 2362) -> SymbolicContext: 2363 # strided NT not supported 2364 if e.is_nested and not isinstance( 2365 e, torch.nested._internal.nested_tensor.NestedTensor 2366 ): 2367 unimplemented("torch.compile does not support strided NestedTensor") 2368 2369 name = source.name() 2370 prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) 2371 shape_env_to_source_to_symbol_cache = ( 2372 prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None 2373 ) 2374 2375 # Get base context if the tensor is a view 2376 view_base_context: Optional[SymbolicContext] = None 2377 if e._is_view(): 2378 base_source = AttrSource(source, "_base") 2379 view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) 2380 2381 if is_traceable_wrapper_subclass(e) and not outer_only: 2382 # Get symbolic context for outer tensor 2383 outer_context = _automatic_dynamic( 2384 e, tx, source, static_shapes, outer_only=True 2385 ) 2386 2387 # Get symbolic contexts for inner tensors 2388 inner_contexts = {} # mapping from attr -> symbolic context 2389 attrs, _ = type(e).__tensor_flatten__(e) 2390 for attr in attrs: 2391 inner_tensor = getattr(e, attr) 2392 inner_source = AttrSource(source, attr) 2393 inner_contexts[attr] = _automatic_dynamic( 2394 inner_tensor, tx, inner_source, static_shapes 2395 ) 2396 2397 return SubclassSymbolicContext( 2398 dynamic_sizes=outer_context.dynamic_sizes, 2399 dynamic_strides=outer_context.dynamic_strides, 2400 constraint_sizes=outer_context.constraint_sizes, 2401 constraint_strides=outer_context.constraint_strides, 2402 view_base_context=view_base_context, 2403 tensor_source=outer_context.tensor_source, 2404 shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, 2405 inner_contexts=inner_contexts, 2406 ) 2407 2408 if static_shapes: 2409 return StatefulSymbolicContext( 2410 dynamic_sizes=[DimDynamic.STATIC] * e.dim(), 2411 dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), 2412 constraint_sizes=[None] * e.dim(), 2413 constraint_strides=[None] * e.dim(), 2414 view_base_context=view_base_context, 2415 tensor_source=source, 2416 shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, 2417 ) 2418 2419 # We preserve the dynamism of inputs. 
For example, when users call 2420 # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. 2421 from torch.fx.experimental.symbolic_shapes import is_nested_int 2422 2423 if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): 2424 return StatefulSymbolicContext( 2425 dynamic_sizes=[ 2426 DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC 2427 for s in e.size() 2428 ], 2429 dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), 2430 constraint_sizes=[None] * e.dim(), 2431 constraint_strides=[None] * e.dim(), 2432 view_base_context=view_base_context, 2433 tensor_source=source, 2434 shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, 2435 ) 2436 2437 # Prep for automatic dynamic 2438 def update_frame_state(size, stride): 2439 # Intentionally shadow e from parent scope so it is not accidentally 2440 # called 2441 e = None 2442 frame_state_entry = None 2443 if name not in tx.output.frame_state: 2444 # If there is no entry for this source, add the tensor to frame state with its current static size. 2445 # E.g., {} -> {"x": [2, 4]} 2446 frame_state_entry = FrameStateSizeEntry(None, None, None) 2447 frame_state_entry.size = list(size) 2448 frame_state_entry.stride = list(stride) 2449 else: 2450 frame_state_entry = tx.output.frame_state[name] 2451 if frame_state_entry.size is not None: 2452 if len(size) != len(frame_state_entry.size): 2453 # If there is already an entry, and the dim mismatches, replace the frame state entry with None. 2454 # E.g. {"x": [2, 3, 4]} -> {"x": None} 2455 log.debug( 2456 "automatic dynamic %s dim %s != %s", 2457 name, 2458 len(size), 2459 frame_state_entry.size, 2460 ) 2461 frame_state_entry.size = None 2462 frame_state_entry.stride = None 2463 else: 2464 # If there is already an entry, and the dim matches, for every size/stride in the frame state which 2465 # disagrees with the current static size/stride, replace it with None. 2466 # E.g., {"x": [2, 3]} -> {"x": [2, None]} 2467 2468 has_size_changed = False 2469 for i, dim in enumerate(frame_state_entry.size): 2470 if dim is not None and size[i] != dim: 2471 log.debug( 2472 "automatic dynamic %s size(%s) %s != %s", 2473 name, 2474 i, 2475 size[i], 2476 dim, 2477 ) 2478 frame_state_entry.size[i] = None 2479 has_size_changed = ( 2480 has_size_changed or frame_state_entry.size[i] is None 2481 ) 2482 2483 # We want to trigger automatic dynamism when strides change, but we have to think about whether the stride 2484 # should be INFER_STRIDE or DYNAMIC. 2485 # 2486 # Case 1: if strides change because of size changes, we might not want to allocate a new symbol for 2487 # stride. Let's say we have a tensor of size (10, 20) and we mark dim=1 as dynamic for size. The resulting size 2488 # will be (10, s0) and the stride can be either (s0, 1) or (s1, 1). In most cases, (s0, 1) is preferred because 2489 # users are not changing both size and stride. 2490 # 2491 # Case 2: Now suppose the size remains the same between the two invocations but the stride 2492 # changes. In this case, we definitely want to mark the changing stride as DYNAMIC. 2493 2494 # Here, we use a heuristic to simplify the determination of dynamic strides. For case 1, we will always 2495 # assume that the stride will be inferred (INFER_STRIDE). This might be suboptimal when the user is doing 2496 # arbitrary size and stride resizing and we fail to trigger dynamism, but we have not seen such cases 2497 # yet. For case 2, we will mark the changing dimensions DYNAMIC.
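                    # A hedged illustration of the two cases (editor's sketch; `f` is a
                    # made-up compiled function):
                    #
                    #     f(torch.randn(10, 20))      # size (10, 20), stride (20, 1)
                    #     f(torch.randn(10, 30))      # Case 1: the stride change follows the
                    #                                 # size change -> keep INFER_STRIDE
                    #
                    #     f(torch.randn(10, 20))      # size (10, 20), stride (20, 1)
                    #     f(torch.randn(20, 10).t())  # Case 2: same size (10, 20) but
                    #                                 # stride (1, 20) -> mark stride DYNAMIC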
2498 if not has_size_changed: 2499 for i, dim in enumerate(frame_state_entry.stride): 2500 if dim is not None and stride[i] != dim: 2501 log.debug( 2502 "automatic dynamic %s stride(%s) %s != %s", 2503 name, 2504 i, 2505 stride[i], 2506 dim, 2507 ) 2508 frame_state_entry.stride[i] = None 2509 tx.output.frame_state[name] = frame_state_entry 2510 2511 if (st := tx.distributed_state) is None: 2512 stride = e.stride() if not is_sparse_any(e) else () 2513 update_frame_state(e.size(), stride) 2514 frame_state_entry = tx.output.frame_state[name] 2515 elif st.all_states is None: 2516 # Preflight, always pretend as if it's static 2517 frame_state_entry = FrameStateSizeEntry( 2518 size=e.size(), scalar=None, stride=e.stride() 2519 ) 2520 st.local_state.input_sizes[name] = list(e.size()) 2521 st.local_state.input_strides[name] = list(e.stride()) 2522 else: 2523 # Apply the updates 2524 for sub_state in st.all_states: 2525 # Not all inputs are necessarily present on all ranks 2526 if name in sub_state.input_sizes and name in sub_state.input_strides: 2527 update_frame_state( 2528 sub_state.input_sizes[name], sub_state.input_strides[name] 2529 ) 2530 frame_state_entry = tx.output.frame_state[name] 2531 2532 # TODO: index export_constraints ahead of time so we don't have to 2533 # do a linear scan every time here 2534 t_id = id(e) 2535 dim2constraint = {} 2536 2537 def update_dim2constraint(dim, constraint_range, name): 2538 if dim in dim2constraint: 2539 from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint 2540 2541 old_constraint_range, old_name = dim2constraint[dim] 2542 new_constraint_range = StrictMinMaxConstraint( 2543 vr=constraint_range.vr & old_constraint_range.vr, 2544 warn_only=False, 2545 ) 2546 # It is possible for (non-None) old_name and name to be different 2547 # but this will only happen when the corresponding Dims can be derived equal.
2548 new_name = old_name or name 2549 dim2constraint[dim] = new_constraint_range, new_name 2550 else: 2551 dim2constraint[dim] = constraint_range, name 2552 2553 if tx.output.export_constraints: 2554 for constraint in tx.output.export_constraints: 2555 if constraint.t_id == t_id: 2556 update_dim2constraint( 2557 constraint.dim, constraint.constraint_range, constraint.name 2558 ) 2559 2560 dynamic_sizes = [] 2561 dynamic_strides = [] 2562 constraint_sizes = [] 2563 constraint_strides = [] 2564 for i in range(e.dim()): 2565 # NB: mark_dynamic has precedence over static 2566 marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) 2567 marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) 2568 marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) 2569 marked_static = i in getattr(e, "_dynamo_static_indices", set()) 2570 2571 # NB: both static and dynamic have precedence over automatic dynamic 2572 automatic_dynamic_size = config.automatic_dynamic_shapes and ( 2573 frame_state_entry.size is None or frame_state_entry.size[i] is None 2574 ) 2575 2576 # if size is None, no need to make stride dynamic 2577 automatic_dynamic_stride = config.automatic_dynamic_shapes and ( 2578 frame_state_entry.size is not None 2579 and ( 2580 frame_state_entry.stride is None or frame_state_entry.stride[i] is None 2581 ) 2582 ) 2583 2584 automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride 2585 2586 # Reflect the user directive in the frame_state 2587 # For dynamic, apply None always 2588 if frame_state_entry.size and marked_dynamic: 2589 log.debug("automatic dynamic %s marked dynamic", name) 2590 frame_state_entry.size[i] = None 2591 frame_state_entry.stride[i] = None 2592 2593 # We will process constraints first, as they will imply that we 2594 # have a dynamic dimension 2595 # Precedence: export constraints > eager constraints 2596 constraint = dim2constraint.get(i) 2597 if constraint is None: 2598 constraint_size = None 2599 constraint_stride = None 2600 if marked_dynamic and not config.allow_ignore_mark_dynamic: 2601 # constraint_stride is deliberately kept None because there is no easy way to provide value ranges for mark_dynamic 2602 constraint_stride = None 2603 if hasattr(e, "_dynamo_dynamic_range"): 2604 dim_range = [ 2605 dr for dr in e._dynamo_dynamic_range if dr.dim == i 2606 ].pop() 2607 if dim_range.min is None and dim_range.max is None: 2608 constraint_size = RelaxedUnspecConstraint(warn_only=False) 2609 else: 2610 from torch.fx.experimental.symbolic_shapes import ( 2611 StrictMinMaxConstraint, 2612 ) 2613 2614 constraint_size = StrictMinMaxConstraint( 2615 vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), 2616 warn_only=False, 2617 ) 2618 else: 2619 constraint_size = RelaxedUnspecConstraint(warn_only=False) 2620 elif not marked_static and automatic_dynamic: 2621 if automatic_dynamic_size: 2622 constraint_size = RelaxedUnspecConstraint(warn_only=True) 2623 if automatic_dynamic_stride: 2624 constraint_stride = RelaxedUnspecConstraint(warn_only=True) 2625 else: 2626 constraint_size = None 2627 constraint_stride = None 2628 else: 2629 constraint_size, name_ = constraint 2630 constraint_stride = None 2631 dim_name = f"{name}.size()[{i}]" 2632 tx.output.shape_env.source_name_to_debug_name[dim_name] = name_ 2633 constraint_sizes.append(constraint_size) 2634 constraint_strides.append(constraint_stride) 2635 2636 if marked_unbacked: 2637 dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED 2638 elif ( 2639 constraint_size is not None 2640 or marked_dynamic 2641 or
marked_weak_dynamic 2642 or is_nested_int(e.size()[i]) 2643 ): 2644 # NB: We could assert static_shapes is False here, but it 2645 # seems better to allow the user to override symbolic_context in this 2646 # case 2647 dynamic_size = DimDynamic.DYNAMIC 2648 elif static_shapes or config.assume_static_by_default or marked_static: 2649 dynamic_size = DimDynamic.STATIC 2650 else: 2651 dynamic_size = DimDynamic.DUCK 2652 2653 if constraint_stride is not None: 2654 dynamic_stride = DimDynamic.DYNAMIC 2655 else: 2656 dynamic_stride = DimDynamic.INFER_STRIDE 2657 2658 dynamic_sizes.append(dynamic_size) 2659 dynamic_strides.append(dynamic_stride) 2660 2661 return StatefulSymbolicContext( 2662 dynamic_sizes=dynamic_sizes, 2663 dynamic_strides=dynamic_strides, 2664 constraint_sizes=constraint_sizes, 2665 constraint_strides=constraint_strides, 2666 view_base_context=view_base_context, 2667 tensor_source=source, 2668 shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, 2669 ) 2670 2671 2672# See note [Tensor Fakification and Symbol Caching] 2673def wrap_to_fake_tensor_and_record( 2674 e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None 2675): 2676 if ( 2677 type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) 2678 or isinstance(e, torch.Tensor) 2679 or is_traceable_wrapper_subclass(e) 2680 ): 2681 assert source is not None 2682 static_shapes, reason = tensor_always_has_static_shape( 2683 e, 2684 is_tensor, 2685 tensor_source=source, 2686 ) 2687 2688 if not parent_context: 2689 symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) 2690 else: 2691 # Parent contexts are passed in when we are recursively creating 2692 # fake tensors for subclasses. A better design would be not to create a 2693 # parent/child relationship, but to recursively call _automatic_dynamic 2694 # as we recursively call wrap_to_fake_tensor_and_record. This runs 2695 # into bugs around how meta_utils knows and works to create fake tensors 2696 # with tensor subclasses. Ideally, dynamo would drive both the recursive 2697 # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. 
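            # A hedged, worked sketch of the lookup performed below (attribute names
            # are illustrative): for a subclass whose __tensor_flatten__ returns
            # attrs ("_values", "_offsets"), the outer call built
            # parent_context.inner_contexts == {"_values": ..., "_offsets": ...}, and
            # the recursive call made with source AttrSource(outer_source, attr)
            # selects the matching inner context here via source.member.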
2698 assert isinstance(source, AttrSource) 2699 inner_context_name = source.member 2700 symbolic_context = parent_context.inner_contexts[inner_context_name] 2701 2702 log.debug( 2703 "wrap_to_fake %s %s %s %s", 2704 source.name(), 2705 tuple(e.shape), 2706 symbolic_context, 2707 type(e), 2708 ) 2709 fake_e = wrap_fake_exception( 2710 lambda: tx.fake_mode.from_tensor( 2711 e, 2712 source=source, 2713 symbolic_context=symbolic_context, 2714 ) 2715 ) 2716 if ( 2717 source is not None 2718 and isinstance(fake_e, FakeTensor) 2719 and (sym_val := fake_e.item_memo) is not None 2720 ): 2721 tx.output.tracked_fakes.append( 2722 TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context) 2723 ) 2724 2725 if is_traceable_wrapper_subclass(fake_e): 2726 attrs, _ = fake_e.__tensor_flatten__() 2727 for attr in attrs: 2728 fake_inner = getattr(fake_e, attr) 2729 inner = getattr(e, attr) 2730 inner_source = AttrSource(source, attr) 2731 wrap_to_fake_tensor_and_record( 2732 inner, 2733 tx, 2734 source=inner_source, 2735 is_tensor=isinstance(fake_inner, torch.Tensor), 2736 parent_context=symbolic_context, 2737 ) 2738 2739 tx.output.tracing_context.tensor_to_context[e] = symbolic_context 2740 if is_sparse_any(fake_e): 2741 # TODO: for TensorGuards, this eventually may need more 2742 # fields for the size/stride of any other constituents 2743 values = fake_e._values() if fake_e.is_sparse else fake_e.values() 2744 tx.output.input_source_to_sizes_strides[source] = { 2745 "size": fake_e.size(), 2746 # TODO: revise this, but for now this stride instead of () 2747 # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1 2748 "stride": (1,) * fake_e.ndim, 2749 "values_size": values.size(), 2750 "values_stride": values.stride(), 2751 } 2752 else: 2753 tx.output.input_source_to_sizes_strides[source] = { 2754 "size": fake_e.size(), 2755 "stride": fake_e.stride(), 2756 } 2757 2758 if ( 2759 is_tensor 2760 and not (static_shapes and source.is_specialized_nn_module()) 2761 and not is_constant_source(source) 2762 ): 2763 tx.output.tracked_fakes.append( 2764 TrackedFake(fake_e, source, symbolic_context) 2765 ) 2766 tx.output.tracked_fakes_id_to_source[id(e)].append(source) 2767 2768 return fake_e 2769 else: 2770 return e 2771 2772 2773 class SourcelessBuilder: 2774 """ 2775 Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects 2776 that are being created/evaporated during inlining (e.g., consider a locally made list of tensors that we then 2777 iterate over). Such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, 2778 there may be reasons to represent it as a ListVariable internally. 2779 2780 NOTE - Objects produced here are born UNGUARDED due to the nature of sources! 2781 2782 NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant 2783 if/else type->VariableTracker trees that were cropping up all over dynamo. 2784 """ 2785 2786 def __init__(self) -> None: 2787 raise AssertionError("Use SourcelessBuilder.create()") 2788 2789 @staticmethod 2790 def create(tx: "InstructionTranslator", value) -> VariableTracker: 2791 value_type = type(value) 2792 fast_handler = SourcelessBuilder._type_handlers.get(value_type) 2793 if fast_handler: 2794 return fast_handler(tx, value) 2795 2796 if isinstance(value, VariableTracker): 2797 # This is always valid to call, and useful for recursive calls.
2798 return value 2799 elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): 2800 return UserDefinedObjectVariable(value) 2801 elif ConstantVariable.is_literal(value): 2802 return ConstantVariable.create(value) 2803 elif callable(value) and trace_rules.lookup_callable(value) is not None: 2804 if is_callable_allowed(value): 2805 tx.output.has_user_defined_allowed_in_graph = True 2806 return trace_rules.lookup_callable(value)(value) 2807 elif is_function_or_wrapper(value): 2808 return trace_rules.lookup(value)(value) 2809 elif isinstance(value, enum.Enum): 2810 return EnumVariable(value) 2811 elif isinstance(value, (type, abc.ABCMeta)): 2812 return UserDefinedClassVariable(value) 2813 elif isinstance(value, types.MethodWrapperType): 2814 return MethodWrapperVariable(value) 2815 elif isinstance(value, torch.fx.graph_module.GraphModule): 2816 return SourcelessGraphModuleVariable(value) 2817 elif isinstance( 2818 value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec) 2819 ): 2820 return UserDefinedObjectVariable(value) 2821 elif PlacementVariable.is_placement(value): 2822 return PlacementVariable(value) 2823 elif DeviceMeshVariable.is_device_mesh(value): 2824 return DeviceMeshVariable(value) 2825 elif isinstance(value, re.Pattern): 2826 return RegexPatternVariable(value) 2827 elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString): 2828 return ConstantVariable.create(str(value)) 2829 unimplemented( 2830 f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}" 2831 ) 2832 2833 @staticmethod 2834 def wrap_constant_literal(value): 2835 assert ConstantVariable.is_literal(value) 2836 return ConstantVariable.create(value=value) 2837 2838 @staticmethod 2839 def make_type_handlers(): 2840 create = SourcelessBuilder.create 2841 handlers = {} 2842 for t in common_constant_types: 2843 handlers[t] = lambda tx, value: ConstantVariable(value) 2844 handlers[set] = lambda tx, value: SetVariable( 2845 [create(tx, x) for x in value], mutable_local=MutableLocal() 2846 ) 2847 handlers[dict] = lambda tx, value: ConstDictVariable( 2848 {create(tx, k): create(tx, v) for k, v in value.items()}, 2849 type(value), 2850 mutable_local=MutableLocal(), 2851 ) 2852 handlers[list] = lambda tx, value: ListVariable( 2853 [create(tx, x) for x in value], mutable_local=MutableLocal() 2854 ) 2855 handlers[tuple] = lambda tx, value: TupleVariable( 2856 [create(tx, x) for x in value] 2857 ) 2858 handlers[torch.Size] = lambda tx, value: SizeVariable( 2859 [create(tx, x) for x in value] 2860 ) 2861 handlers[collections.OrderedDict] = handlers[dict] 2862 handlers[immutable_dict] = handlers[dict] 2863 handlers[immutable_list] = handlers[list] 2864 handlers[random.Random] = lambda tx, value: RandomClassVariable() 2865 handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) 2866 2867 handlers[ 2868 torch.distributions.constraints._Real 2869 ] = lambda tx, value: UserDefinedObjectVariable( 2870 value, mutable_local=MutableLocal() 2871 ) 2872 handlers[ 2873 torch.distributions.constraints._Interval 2874 ] = lambda tx, value: UserDefinedObjectVariable( 2875 value, mutable_local=MutableLocal() 2876 ) 2877 handlers[ 2878 torch.distributions.constraints.Constraint 2879 ] = lambda tx, value: UserDefinedObjectVariable( 2880 value, mutable_local=MutableLocal() 2881 ) 2882 2883 def passthrough(tx: "InstructionTranslator", value): 2884 return value 2885 2886 for cls in VariableTrackerMeta.all_subclasses: 2887 handlers[cls] = passthrough 2888 return 
handlers 2889 2890 2891SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers() 2892