# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import inspect
import logging
import math
import re
from typing import Dict, List, TYPE_CHECKING

import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._guards import TracingContext
from torch._logging import warning_once
from torch._streambase import _StreamBase
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

from .. import config, polyfills, variables
from ..codegen import PyCodegen
from ..create_parameter_op import (
    can_convert_to_tracable_parameter,
    new_parameter_placeholder,
    tracable_create_parameter,
)
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import SyntheticLocalSource
from ..utils import (
    check_unspec_or_constant_args,
    guard_if_dyn,
    has_torch_function,
    hashable,
    product,
    proxy_args_kwargs,
    unwrap_if_wrapper,
)
from .base import VariableTracker
from .ctx_manager import (
    AutocastModeVariable,
    NullContextVariable,
    TorchFunctionDisableVariable,
)
from .distributed import DistributedVariable, ProcessGroupVariable
from .lists import ListVariable, TupleVariable
from .torch_function import (
    can_dispatch_torch_function,
    dispatch_torch_function,
    TorchFunctionModeStackVariable,
)


try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

try:
    from torch.distributed._composable.fsdp import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None  # type: ignore[assignment]


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)

supported_ctx_manager_classes = dict.fromkeys(
    [
        torch.profiler.profiler.profile,
        torch.autograd.forward_ad._set_fwd_grad_enabled,
        torch.autograd.forward_ad.dual_level,
        torch.autograd.profiler.profile,
        torch.autograd.profiler.record_function,
        torch._C.DisableTorchFunctionSubclass,
        torch._functorch.vmap.vmap_increment_nesting,
        torch._functorch.eager_transforms.grad_increment_nesting,
        torch._functorch.eager_transforms.jvp_increment_nesting,
        torch._functorch.eager_transforms.enable_inplace_requires_grad,
        torch.amp.autocast_mode.autocast,
        torch.autograd.grad_mode.enable_grad,
        torch.autograd.grad_mode.inference_mode,
        torch.autograd.grad_mode.no_grad,
        torch.autograd.grad_mode.set_grad_enabled,
        torch.autograd.graph.disable_saved_tensors_hooks,
        torch.cpu.amp.autocast_mode.autocast,
        torch.cuda.amp.autocast_mode.autocast,
    ]
)


REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
    [
        torch.onnx.operators.shape_as_tensor,
        torch._shape_as_tensor,
    ]
)
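
# Functions that are evaluated at trace time instead of being put in the graph,
# as long as all of their arguments are constants (see
# BaseTorchVariable.can_constant_fold_through below, which also allows any
# function from the `math` module).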
constant_fold_functions = [
    torch._assert,
    torch._utils._get_device_index,
    torch._C._get_cublas_allow_tf32,
    torch._C._is_any_autocast_enabled,
    torch.cuda.get_device_properties,
    torch.cuda.is_available,
    torch.distributed.is_available,
    torch.get_autocast_dtype,
    torch.get_autocast_gpu_dtype,
    torch.get_default_dtype,
    torch.is_autocast_cache_enabled,
    torch.is_autocast_cpu_enabled,
    torch.is_autocast_enabled,
    torch.is_complex,
    torch.is_floating_point,
    torch.nn.functional._Reduction.get_enum,  # type: ignore[attr-defined]
    torch.promote_types,
    torch._C._get_privateuse1_backend_name,
    torch.autograd._is_checkpoint_valid,
]
if torch.distributed.is_available():
    constant_fold_functions.extend(
        [
            torch.distributed.is_initialized,
            torch.distributed.get_rank,
            torch.distributed.get_world_size,
        ]
    )
# Convert to dict for O(1) access times
constant_fold_functions = dict.fromkeys(constant_fold_functions)


tracing_state_functions = {
    torch.jit.is_scripting: False,
    torch.jit.is_tracing: False,
    torch._C._get_tracing_state: None,
    torch.fx._symbolic_trace.is_fx_tracing: False,
    torch.onnx.is_in_onnx_export: False,
    torch._dynamo.external_utils.is_compiling: True,
    torch._utils.is_compiling: True,
    torch.compiler.is_compiling: True,
    torch.compiler.is_dynamo_compiling: True,
    torch.nn.modules.activation._is_make_fx_tracing: False,
}

bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])


class BaseTorchVariable(VariableTracker):
    """common base for all torch.* functions, classes, modules and other things"""

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(value, source=source)

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def reconstruct(self, codegen):
        try:
            name = f"{self.value.__module__}.{self.value.__name__}"
        except Exception:
            name = f"torch_obj_{id(self.value)}"
        unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        codegen.extend_output(
            codegen.setup_globally_cached(unique_var_name, self.value)
        )

    def as_proxy(self):
        return self.value

    def as_python_constant(self):
        return self.value

    def call_hasattr(self, tx: "InstructionTranslator", name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True
        return getattr(self.value, "__module__", None) == "math"


class TorchCtxManagerClassVariable(BaseTorchVariable):
    """Points to a context manager class in torch.* that dynamo has an implementation for"""

    def __repr__(self) -> str:
        return f"TorchCtxManagerClassVariable({self.value})"

    @staticmethod
    def is_matching_cls(value):
        # Unwrap if it's a functools.lru_cache wrapper
        value = unwrap_if_wrapper(value)
        # We can't do isinstance(value, type) check because some ctx managers
        # are implemented as a function decorated by contextlib.contextmanager,
        # E.g., torch._functorch.vmap.vmap_increment_nesting.
        return (
            # Context manager type or function with @contextmanager is callable
            callable(value)
            and (
                hashable(value)  # accesses value.__hash__()
                and value in supported_ctx_manager_classes
            )
        )
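
    # Calling the context manager class is mapped to the corresponding Dynamo
    # context-manager variable below (e.g. torch.no_grad() -> GradModeVariable),
    # so that entering and exiting it can be traced.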
    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import (
            DisabledSavedTensorsHooksVariable,
            DualLevelContextManager,
            FSDPParamGroupUseTrainingStateVariable,
            GradIncrementNestingCtxManagerVariable,
            GradInplaceRequiresGradCtxManagerVariable,
            GradModeVariable,
            InferenceModeVariable,
            JvpIncrementNestingCtxManagerVariable,
            SetFwdGradEnabledContextManager,
            StreamVariable,
            VmapIncrementNestingCtxManagerVariable,
        )

        if self.value is torch.no_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, False)
                return ctx.call_function(tx, args, kwargs)
            else:
                return GradModeVariable.create(tx, False)
        elif self.value is torch.enable_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, True)
                return ctx.call_function(tx, args, kwargs)
            return GradModeVariable.create(tx, True)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(
                tx, args[0].as_python_constant(), initialized=True
            )
        elif self.value is torch.inference_mode:
            assert len(args) <= 1 and len(kwargs) == 0
            inf_mode = args[0].as_python_constant() if len(args) == 1 else True
            return InferenceModeVariable.create(tx, inf_mode)
        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls

            return wrap_fx_proxy_cls(
                StreamVariable,
                tx,
                tx.output.create_proxy(
                    "call_function",
                    self.value,
                    (),
                    {},
                ),
            )
        elif self.value in (
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ):
            return AutocastModeVariable.create(self.value, args, kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            warning_once(log, "Profiler function %s will be ignored", self.value)
            return NullContextVariable()
        elif self.value is torch._C.DisableTorchFunctionSubclass:
            assert not (args or kwargs)
            return TorchFunctionDisableVariable.create(tx)
        elif self.value is torch._functorch.vmap.vmap_increment_nesting:
            assert len(args) == 2
            return VmapIncrementNestingCtxManagerVariable.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting:
            assert len(args) == 0
            return JvpIncrementNestingCtxManagerVariable.create(tx)
        elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled:
            assert len(args) == 1
            return SetFwdGradEnabledContextManager.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch.autograd.forward_ad.dual_level:
            assert len(args) == 0
            return DualLevelContextManager.create(tx)
        elif self.value is torch._functorch.eager_transforms.grad_increment_nesting:
            assert len(args) == 0
            return GradIncrementNestingCtxManagerVariable.create(tx)
        elif (
            self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad
        ):
            assert len(args) == 1
            return GradInplaceRequiresGradCtxManagerVariable.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
            assert len(args) == 1
            return DisabledSavedTensorsHooksVariable.create(
                tx, args[0].as_python_constant()
            )
        elif (
            _fsdp_param_group is not None
            and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state
        ):
            assert len(args) == 2
            return FSDPParamGroupUseTrainingStateVariable.create(
                tx, args[0], args[1].as_python_constant()
            )

        return super().call_function(tx, args, kwargs)


class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in FX graph"""

    def __repr__(self) -> str:
        return f"TorchInGraphFunctionVariable({self.value})"

    def get_function(self):
        return self.value

    @staticmethod
    @functools.lru_cache(None)
    def _get_handlers():
        """Build a dict from function -> method to handle it so that we are O(1)
        in terms of the number of functions with special handling."""
        handlers = {}

        def register(*fns):
            def _register(handler):
                for fn in fns:
                    assert fn not in handlers, fn
                    handlers[fn] = handler
                return handler

            assert callable(fns[0])
            return _register

        from torch.backends.cuda import SDPAParams

        from . import (
            ConstantVariable,
            DeterministicAlgorithmsVariable,
            GradModeVariable,
            StreamContextVariable,
            SymNodeVariable,
            TensorVariable,
            UserDefinedObjectVariable,
        )
        from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls

        @register(*tracing_state_functions)
        def handle_tracing_state_functions(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
            # See: https://github.com/pytorch/pytorch/issues/110765
            if self.value in (
                torch._utils.is_compiling,
                torch._dynamo.external_utils.is_compiling,
                torch.compiler.is_compiling,
                torch.compiler.is_dynamo_compiling,
            ):
                tx.mark_inconsistent_side_effects()
            return ConstantVariable.create(tracing_state_functions[self.value])

        @register(torch.overrides.get_default_nowrap_functions.__wrapped__)
        def handle_get_default_nowrap_functions(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            # [Note: __torch_function__] we return empty here because we restrict
            # the set of functions that we trace __torch_function__ on to
            # functions outside of the actual set. Implementing this properly will require
            # implementing some variable types to track and compare tensor getset descriptors
            return SourcelessBuilder.create(
                tx, torch.overrides.get_default_nowrap_functions()
            )

        @register(torch.ops.inductor.accumulate_grad_.default)
        def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs
            )

        @register(math.radians)
        def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
            if not check_unspec_or_constant_args(args, kwargs):
                # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfills.radians), args, kwargs
                )

        @register(torch.is_tensor, torch.overrides.is_tensor_like)
        def handle_is_tensor(self, tx: "InstructionTranslator", arg):
            if isinstance(arg, TensorVariable) or (
                self.value is torch.overrides.is_tensor_like
                and isinstance(arg, UserDefinedObjectVariable)
                and hasattr(arg.value, "__torch_function__")
            ):
                return ConstantVariable.create(True)
            else:
                return ConstantVariable.create(False)

        @register(
            torch.is_floating_point,
            torch.is_complex,
        )
        def handle_is_floating_point(self, tx: "InstructionTranslator", input):
            input_arg = input
            if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
                if self.value is torch.is_floating_point:
                    return ConstantVariable.create(input_arg.dtype.is_floating_point)
                elif self.value is torch.is_complex:
                    return ConstantVariable.create(input_arg.dtype.is_complex)
                else:
                    raise AssertionError(f"calling {self.value}")

        @register(torch.numel)
        def handle_numel(self, tx: "InstructionTranslator", input):
            if isinstance(input, TensorVariable) and input.size is not None:
                return ConstantVariable.create(product(input.size))
            elif isinstance(input, TensorVariable):
                # Workaround dynamic shapes issue
                return input.call_method(tx, "numel", [], {})

        @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
        def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input):
            assert isinstance(input, TensorVariable)
            return input.call_method(tx, "size", [], {})

        @register(
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        )
        def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs):
            return self._call_ntuple(tx, args, kwargs)

        @register(torch.is_grad_enabled)
        def handle_is_grad_enabled(self, tx):
            install_guard(GradModeVariable._guards_singleton)
            return ConstantVariable.create(torch.is_grad_enabled())

        @register(torch.use_deterministic_algorithms)
        def handle_use_deterministic_algorithms(
            self, tx: "InstructionTranslator", mode, warn_only=False
        ):
            if warn_only and warn_only.as_python_constant():
                unimplemented("torch.use_deterministic_algorithms(warn_only=True)")
            return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())

        @register(torch.are_deterministic_algorithms_enabled)
        def handle_are_deterministic_algorithms_enabled(self, tx):
            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
            return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())
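
        # torch._C._is_torch_function_enabled is answered from Dynamo's own
        # bookkeeping (tx.output.torch_function_enabled), with a guard installed
        # so the compiled code is re-traced if the global torch-function state changes.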
        @register(torch._C._is_torch_function_enabled)
        def handle_is_torch_function_enabled(self, tx):
            install_guard(TorchFunctionDisableVariable._guards_singleton)
            return ConstantVariable.create(tx.output.torch_function_enabled)

        @register(
            torch.overrides.has_torch_function,
            torch.overrides.has_torch_function_variadic,
            torch.overrides.has_torch_function_unary,
        )
        def handle_has_torch_function(self, tx: "InstructionTranslator", *args):
            elems = (
                args[0].unpack_var_sequence(tx)
                if len(args) == 1 and isinstance(args[0], TupleVariable)
                else args
            )
            return ConstantVariable.create(
                any(has_torch_function(x) for x in elems),
            )

        @register(
            *dict.fromkeys(  # remove duplicates
                device_interface.stream
                for _, device_interface in get_registered_device_interfaces()
            )
        )
        def handle_device_interface_stream(self, tx: "InstructionTranslator", stream):
            return StreamContextVariable.create(tx, stream)

        @register(torch.from_numpy)
        def handle_from_numpy(self, tx: "InstructionTranslator", *args):
            if not config.trace_numpy:
                unimplemented("torch.from_numpy. config.trace_numpy is False")
            if not np:
                unimplemented("torch.from_numpy. NumPy is not available")
            return wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.as_tensor,
                    *proxy_args_kwargs(args, {}),
                ),
                example_value=None,
            )

        @register(torch.jit.annotate)
        def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value):
            return the_value

        @register(torch.backends.cudnn.is_acceptable)
        def handle_cudnn_is_acceptable(
            self, tx: "InstructionTranslator", tensor, *extra
        ):
            # is_acceptable(tensor) returns true if
            #     (a) tensor dtype/device are supported by cudnn
            #     (b) cudnn is available
            #     (c) some initialization has completed
            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
            assert not extra, "Expect 1 input to cudnn.is_acceptable"
            assert isinstance(
                tensor, TensorVariable
            ), "Expect input to cudnn.is_acceptable to be a tensor"
            tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device)
            return ConstantVariable.create(
                torch.backends.cudnn.is_acceptable(tensor_inp)
            )

        @register(torch.utils.hooks.BackwardHook)
        def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs):
            return variables.BackwardHookVariable.create(tx, *args, **kwargs)

        @register(torch.nn.Parameter)
        def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs):
            return self.call_nn_parameter(tx, *args, **kwargs)

        @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int)
        def handle_sym_size(self_, tx, self, dim=None):
            # we see this when retracing already traced code
            if dim is not None:
                return self.call_method(tx, "size", [dim], {})

        @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int)
        def handle_sym_stride(self_, tx, self, dim=None):
            if dim is not None:
                return self.call_method(tx, "stride", [dim], {})

        @register(torch.addcdiv)
        def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs):
            if len(args) == 3 and "value" in kwargs and len(kwargs) == 1:
                # decompose addcdiv into constituent ops, prevents a graph break due to converting
                # value to a scalar
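                # addcdiv(input, t1, t2, value=v) computes input + v * (t1 / t2),
                # so emit the div/mul/add as separate traced calls below.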
                result = TorchInGraphFunctionVariable(torch.div).call_function(
                    tx, [*args[1:]], {}
                )
                result = TorchInGraphFunctionVariable(torch.mul).call_function(
                    tx, [result, kwargs["value"]], {}
                )
                return TorchInGraphFunctionVariable(torch.add).call_function(
                    tx, [args[0], result], {}
                )

        @register(torch._foreach_lerp_)
        def handle_inplace_foreach_lerp_scalar(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
                    args,
                    kwargs,
                )

        @register(torch._foreach_pow)
        def handle_foreach_pow_scalar(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            # In eager it's more performant to call item() from within the C op implementation;
            # in compile, it's more performant to not graph break.
            if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
                    args,
                    kwargs,
                )

        @register(torch._assert)
        def handle_assert(self, tx: "InstructionTranslator", condition, message):
            if (condition.is_python_constant() and condition.as_python_constant()) or (
                isinstance(condition, variables.SymNodeVariable)
                and condition.evaluate_expr()
            ):
                return ConstantVariable(None)

        @register(SDPAParams)
        def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs):
            return wrap_fx_proxy(
                tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch._C._SDPAParams,
                    *proxy_args_kwargs(args, kwargs),
                ),
                param_vars=args,
            )

        if DistributedVariable.is_available():
            from torch.distributed.distributed_c10d import (
                _get_group_size_by_name,
                _get_group_tag,
                _rank_not_in_group,
                _resolve_group_name_by_ranks_and_tag,
                get_process_group_ranks,
            )
            from torch.distributed.tensor import DTensor

            @register(
                _get_group_size_by_name,
                _get_group_tag,
                _rank_not_in_group,
                get_process_group_ranks,
                _resolve_group_name_by_ranks_and_tag,
            )
            def handle_constant_processgroup_functions(
                self, tx: "InstructionTranslator", *args
            ):
                # because the input is a "ProcessGroupVariable", we'll be guarding on its
                # ID_MATCH based on how it was constructed.

                # We desugar it at trace-time into ranks by directly calling the util
                # and baking the result into the trace
                if len(args) == 1:
                    # group or group name
                    assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable))
                elif len(args) == 2:
                    # ranks + tag
                    assert isinstance(args[0], ListVariable) and isinstance(
                        args[1], ConstantVariable
                    )
                else:
                    raise AssertionError(
                        f"Invalid group value ({args}) for constant pg "
                        f"function {self.value}"
                    )
                args_as_value = [arg.as_python_constant() for arg in args]
                invocation_result = self.value(*args_as_value)

                # Note - while we *could* cook up sources around invocations, like a FunctionSource
                # the space of invoking functions in the middle of the guard chain is very iffy. As such,
                # guard propagation via options is the best we can do.
                return SourcelessBuilder.create(tx, invocation_result)
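
            # DTensor.from_local takes non-proxyable arguments (device mesh,
            # placements, ...); those are captured as Python constants in a small
            # wrapper function so that only the local tensor (and any shape/stride
            # hints) flow through the FX graph as proxies.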
            @register(DTensor.from_local)
            def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
                # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
                # and rewrite args to have only proxyable args, then insert call_function
                args_as_value = [x.as_python_constant() for x in args[1:]]
                kwargs_as_value = {
                    k: v.as_python_constant()
                    for k, v in kwargs.items()
                    if k not in ["shape", "stride"]
                }
                kwargs_to_be_proxied = {
                    k: kwargs[k] for k in ["shape", "stride"] if k in kwargs
                }

                def fn_with_prim_types(x, shape=None, stride=None):
                    return self.value(
                        x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride
                    )

                # attach the same function name for better debugging
                fn_with_prim_types.__name__ = "prim " + self.value.__name__

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        fn_with_prim_types,
                        *proxy_args_kwargs(
                            [args[0]],
                            kwargs_to_be_proxied,
                        ),
                    ),
                )

        @register(torch.nested.nested_tensor)
        def handle_nested_tensor(
            self,
            tx: "InstructionTranslator",
            tensor_list=None,
            *args,
            layout=None,
            **kwargs,
        ):
            from .lists import BaseListVariable

            if layout and layout.as_python_constant() == torch.strided:
                unimplemented("torch.compile does not support strided NestedTensor")
            if not isinstance(tensor_list, BaseListVariable):
                unimplemented("nested_tensor with non-list input")

        @register(torch.nn.functional.one_hot)
        def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs):
            if len(args) + len(kwargs) == 1 or (
                len(args) == 2
                and args[1].is_python_constant()
                and args[1].as_python_constant() == -1
            ):
                unimplemented(
                    "torch.nn.functional.one_hot with data-dependent output shape"
                )

        @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
        def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr):
            if isinstance(expr, SymNodeVariable):
                # TODO: this probably should be folded somewhere else but I'm not sure where
                # TODO: some of the other symbolic_shapes special tools can also get this treatment too
                return variables.ConstantVariable.create(
                    torch.fx.experimental.symbolic_shapes.guard_size_oblivious(
                        expr.sym_num
                    )
                )
            elif isinstance(expr, ConstantVariable):
                return expr

        @register(torch._C._autograd._unsafe_set_version_counter)
        def handle_unsafe_set_version_counter(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            from ..tensor_version_op import _unsafe_set_version_counter

            return TorchInGraphFunctionVariable(
                _unsafe_set_version_counter
            ).call_function(tx, [*args], kwargs)

        @register(torch.tensor)
        def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs):
            def check_any_unspec(x):
                # NB: This includes UnspecializedPythonVariable
                if isinstance(x, (TensorVariable, SymNodeVariable)):
                    return True
                elif isinstance(x, (ListVariable, TupleVariable)):
                    return any(check_any_unspec(y) for y in x.items)
                # TODO: there may be other recursive structures you need to
                # check
                else:
                    return False

            data_arg = None
            if args:
                data_arg = args[0]
            elif "data" in kwargs:
                data_arg = kwargs["data"]

            # NB: OK to pass torch.tensor(tensor), this will trace fine
            if not isinstance(data_arg, TensorVariable) and check_any_unspec(data_arg):
                # This is slower and less canonical, so only use it if we
                # have to
                return TorchInGraphFunctionVariable(torch._refs.tensor).call_function(
                    tx, [*args], kwargs
                )

        @register(torch._C._pop_torch_function_stack)
        def handle_pop_torch_function(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
            if not tx.symbolic_torch_function_mode_stack:
                raise unimplemented("Popping from an empty torch function mode stack")
            TorchFunctionModeStackVariable.register_mutation(tx)
            return tx.symbolic_torch_function_mode_stack.pop()

        @register(torch._C._push_on_torch_function_stack)
        def handle_push_torch_function(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert len(args) == 1 and not kwargs
            TorchFunctionModeStackVariable.register_mutation(tx)
            tx.symbolic_torch_function_mode_stack.append(args[0])
            return ConstantVariable.create(None)

        @register(torch._C._len_torch_function_stack)
        def handle_len_torch_function(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
            return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))

        @register(torch.set_default_device)
        def handle_set_default_device(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            # Today this is inserted in the graph; once TF mode handling is
            # complete, we can trace the device context like any other TF mode
            # and remove this special handling.
            # Insert the TF mode representing the device context at
            # the bottom of the stack to match the eager semantics.
            # Running the graph will ensure that the DeviceContext mode is
            # at the correct position in the stack.
            TorchFunctionModeStackVariable.register_mutation(tx)
            if args[0].is_python_constant() and args[0].as_python_constant() is None:
                TorchFunctionModeStackVariable.clear_default_device(tx)
            else:
                TorchFunctionModeStackVariable.register_device_context_insertion(tx)

            return None

        return handlers

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable, SymNodeVariable, TensorVariable
        from .builder import wrap_fx_proxy

        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
            # constant fold
            return ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )

        special_handler = self._get_handlers().get(self.value)
        if special_handler:
            result = special_handler(self, tx, *args, **kwargs)
            if result:
                return result

        if can_dispatch_torch_function(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
        else:
            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)

            all_ints_or_floats = all(
                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
                for x in args
            )
            if (
                getattr(self.value, "__module__", "") == "torch"
                and self.value.__name__ in bin_ops
                and any_symints_or_symfloats
                and all_ints_or_floats
            ):
                msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
                log.warning(msg)
                unimplemented(msg)
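
            # If any argument is a SymInt/SymFloat, math.* functions are swapped for
            # their torch._sym_* counterparts when one exists
            # (e.g. math.sqrt -> torch._sym_sqrt), so the result stays symbolic.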

            # TODO(voz): Replace w/ dynamic shape rewrite table.
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
            # of value + args to determine this.
            fn_ = self.value
            if any_symints_or_symfloats:
                torch_sym_op = f"_sym_{self.value.__name__}"
                if getattr(self.value, "__module__", None) == "math" and hasattr(
                    torch, torch_sym_op
                ):
                    fn_ = getattr(torch, torch_sym_op)

            fake_out_shape = None
            if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
                # Calling fake tensor propagation can mutate the out= tensor in
                # tx.output.tracked_fakes. tracked_fakes are used to apply
                # symbolic_shape guards. Mutating them destroys the information
                # prior to tracing, which is essential for creating right
                # guards. So save the shape now, and check later if it has
                # changed. If it has, graph break.
                fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            if (
                isinstance(tensor_variable, TensorVariable)
                and "requires_grad" in kwargs
                and kwargs["requires_grad"].as_python_constant()
            ):
                unimplemented(
                    """factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
                )

            if "out" in kwargs and not (
                isinstance(kwargs["out"], variables.ConstantVariable)
                and kwargs["out"].as_python_constant() is None
            ):
                # out variants of torch operators like torch.sort and
                # torch.sigmoid mutate the tensors in the out field. Track such
                # tensors and rewrite the symbolic locals.
                if isinstance(tensor_variable, TupleVariable):
                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
                    output_tensor_names = [
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
                    ]
                    for idx, name in enumerate(output_tensor_names):
                        if name in tx.symbolic_locals:
                            tx.symbolic_locals[name] = tensor_variable.items[idx]
                    for out_tensor, result_tensor in zip(
                        kwargs["out"].items, tensor_variable.items
                    ):
                        if (
                            out_tensor.source
                            and out_tensor in tx.output.graphargs
                            and isinstance(out_tensor, variables.TensorVariable)
                            and isinstance(result_tensor, variables.TensorVariable)
                            and out_tensor.size != result_tensor.size
                        ):
                            # It's hard to get out variants with resizing on graph inputs to work
                            # properly across dynamo/aot/inductor; just fall back.
                            unimplemented("out variants with resizing on graph inputs")
                elif isinstance(tensor_variable, TensorVariable):
                    assert isinstance(kwargs["out"], TensorVariable)
                    assert "example_value" in kwargs["out"].proxy.node.meta
                    fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
                    if (
                        kwargs["out"].source
                        and kwargs["out"] in tx.output.graphargs
                        and fake_out_shape != fake_tensor.shape
                    ):
                        # It's hard to get out variants with resizing on graph inputs to work
                        # properly across dynamo/aot/inductor; just fall back.
                        unimplemented("out variants with resizing on graph inputs")
                    if not torch._prims_common.is_contiguous(fake_out):
                        # It's difficult to handle strides correctly in functionalization
                        # when calling an out= op with a non-contiguous out argument
                        unimplemented(
                            "out= op was called where output tensor was non-contiguous"
                        )
                    name = tx.find_symbolic_locals_name(kwargs["out"])
                    if name in tx.symbolic_locals:
                        tx.symbolic_locals[name] = tensor_variable
                elif (
                    isinstance(tensor_variable, ConstantVariable)
                    and tensor_variable.value is None
                ):
                    # Handle out-variant custom ops that return None.
                    if isinstance(kwargs["out"], TensorVariable):
                        assert "example_value" in kwargs["out"].proxy.node.meta
                        fake_out = kwargs["out"].proxy.node.meta["example_value"]
                        if not torch._prims_common.is_contiguous(fake_out):
                            # It's difficult to handle strides correctly in functionalization
                            # when calling an out= op with a non-contiguous out argument
                            unimplemented(
                                "out= op was called where output tensor was non-contiguous"
                            )
                    elif isinstance(kwargs["out"], ListVariable):
                        for idx, x in enumerate(kwargs["out"].items):
                            assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
                            fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
                            if not torch._prims_common.is_contiguous(fake_out):
                                # It's difficult to handle strides correctly in functionalization
                                # when calling an out= op with a non-contiguous out argument
                                unimplemented(
                                    "out= op was called where some of the output tensors were non-contiguous"
                                )
                    else:
                        unimplemented(f"out variant of {type(kwargs['out'])}")

            return tensor_variable

    def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)
        assert not kwargs

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable.create(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple)
        else:
            return handle_ntuple(args[0])
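
    # torch.nn.Parameter(data) inside a compiled region is handled in one of two
    # ways: if `data` is a graph input with a source, the Parameter is constructed
    # in pregraph bytecode (_nn_param_via_prefix_insert); otherwise a synthetic
    # placeholder is added as a graph input and tracable_create_parameter ties the
    # traced data to it.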
    @classmethod
    def call_nn_parameter(cls, tx, data=None, requires_grad=True):
        """A call to torch.nn.Parameter() gets lifted to before the graph"""
        if tx.export:
            unimplemented("nn parameter construction not supported with export")

        if isinstance(requires_grad, variables.VariableTracker):
            try:
                requires_grad = requires_grad.as_python_constant()
            except NotImplementedError:
                unimplemented("Parameter(requires_grad=...) not constant")

        if not isinstance(data, variables.TensorVariable):
            unimplemented(f"Parameter(data={data}) not implemented")

        # this results in cleaner graphs, but only works for inputs
        if data.source:
            return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

        if is_traceable_wrapper_subclass_type(data.class_type):
            unimplemented("Parameter constructor with tensor subclass NYI")

        if not can_convert_to_tracable_parameter():
            unimplemented("Workaround for issues with nn_parameter construction")

        try:
            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
            dtype = data.var_getattr(tx, "dtype").as_python_constant()
            device = data.var_getattr(tx, "device").as_python_constant()
        except NotImplementedError as e:
            unimplemented(f"Parameter not python_constant: {e}")

        placeholder = tx.output.synthetic_graph_input(
            new_parameter_placeholder, [shape, dtype, device, requires_grad]
        )
        if data.requires_grad:
            data = data.call_method(tx, "detach", [], {})

        from .builder import wrap_fx_proxy

        result = wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function",
                tracable_create_parameter,
                (data.as_proxy(), placeholder.as_proxy()),
                {},
            ),
        )
        assert isinstance(result, variables.TensorVariable)
        result.class_type = torch.nn.Parameter

        # TODO(jansel/bdhirsh) - There is some issue with
        # tracable_create_parameter. It does not seem to use the right
        # grad_enabled. Since this is a parameter, we can just override the
        # has_grad_fn field to False to work around the issue.
        result.has_grad_fn = False

        # In reconstruct() we should use the original parameter. The one returned by the graph will be an alias.
        result.source = placeholder.source

        # TODO(jansel): if the new param falls out of scope, currently it won't get freed until
        # the end of the graph. We should fix this.
        return result

    @staticmethod
    def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
        # Alternate version if we have a .source
        from .builder import VariableBuilder

        varname = tx.output.new_var()

        # construct the nn.Parameter before the graph and save it to varname
        cg = PyCodegen(tx)
        cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter"))
        cg(data.source)
        cg(variables.ConstantVariable(requires_grad))
        cg.call_function(2, False)
        cg.store(varname)
        tx.output.pregraph_bytecode.extend(cg.get_instructions())

        data_node = data.as_proxy().node
        if data_node.op not in ("placeholder", "get_attr"):
            unimplemented(
                "Unexpected type of data placeholder op for parameter construction"
            )

        # add the newly constructed nn.Parameter as a graph input
        source = SyntheticLocalSource(varname)
        example_value = torch.nn.Parameter(
            tx.output.example_value_from_input_node(data.as_proxy().node)
        )
        result = VariableBuilder(tx, source)(example_value)
        # No need to guard on this since we already guarded on `data`.
        # These guards would fail since varname doesn't exist until after the function starts
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result