xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/torch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import functools
4import inspect
5import logging
6import math
7import re
8from typing import Dict, List, TYPE_CHECKING
9
10import torch._C
11import torch._refs
12import torch.fx
13import torch.nn
14import torch.onnx.operators
15from torch._guards import TracingContext
16from torch._logging import warning_once
17from torch._streambase import _StreamBase
18from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
19
20from .. import config, polyfills, variables
21from ..codegen import PyCodegen
22from ..create_parameter_op import (
23    can_convert_to_tracable_parameter,
24    new_parameter_placeholder,
25    tracable_create_parameter,
26)
27from ..device_interface import get_registered_device_interfaces
28from ..exc import unimplemented
29from ..guards import GuardBuilder, install_guard
30from ..source import SyntheticLocalSource
31from ..utils import (
32    check_unspec_or_constant_args,
33    guard_if_dyn,
34    has_torch_function,
35    hashable,
36    product,
37    proxy_args_kwargs,
38    unwrap_if_wrapper,
39)
40from .base import VariableTracker
41from .ctx_manager import (
42    AutocastModeVariable,
43    NullContextVariable,
44    TorchFunctionDisableVariable,
45)
46from .distributed import DistributedVariable, ProcessGroupVariable
47from .lists import ListVariable, TupleVariable
48from .torch_function import (
49    can_dispatch_torch_function,
50    dispatch_torch_function,
51    TorchFunctionModeStackVariable,
52)
53
54
55try:
56    import numpy as np
57except ModuleNotFoundError:
58    np = None  # type: ignore[assignment]
59
60try:
61    from torch.distributed._composable.fsdp import _fsdp_param_group
62except ModuleNotFoundError:
63    _fsdp_param_group = None  # type: ignore[assignment]
64
65
66if TYPE_CHECKING:
67    from torch._dynamo.symbolic_convert import InstructionTranslator
68
69
70log = logging.getLogger(__name__)
71
72supported_ctx_manager_classes = dict.fromkeys(
73    [
74        torch.profiler.profiler.profile,
75        torch.autograd.forward_ad._set_fwd_grad_enabled,
76        torch.autograd.forward_ad.dual_level,
77        torch.autograd.profiler.profile,
78        torch.autograd.profiler.record_function,
79        torch._C.DisableTorchFunctionSubclass,
80        torch._functorch.vmap.vmap_increment_nesting,
81        torch._functorch.eager_transforms.grad_increment_nesting,
82        torch._functorch.eager_transforms.jvp_increment_nesting,
83        torch._functorch.eager_transforms.enable_inplace_requires_grad,
84        torch.amp.autocast_mode.autocast,
85        torch.autograd.grad_mode.enable_grad,
86        torch.autograd.grad_mode.inference_mode,
87        torch.autograd.grad_mode.no_grad,
88        torch.autograd.grad_mode.set_grad_enabled,
89        torch.autograd.graph.disable_saved_tensors_hooks,
90        torch.cpu.amp.autocast_mode.autocast,
91        torch.cuda.amp.autocast_mode.autocast,
92    ]
93)
94
95
96REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
97    [
98        torch.onnx.operators.shape_as_tensor,
99        torch._shape_as_tensor,
100    ]
101)
102
103constant_fold_functions = [
104    torch._assert,
105    torch._utils._get_device_index,
106    torch._C._get_cublas_allow_tf32,
107    torch._C._is_any_autocast_enabled,
108    torch.cuda.get_device_properties,
109    torch.cuda.is_available,
110    torch.distributed.is_available,
111    torch.get_autocast_dtype,
112    torch.get_autocast_gpu_dtype,
113    torch.get_default_dtype,
114    torch.is_autocast_cache_enabled,
115    torch.is_autocast_cpu_enabled,
116    torch.is_autocast_enabled,
117    torch.is_complex,
118    torch.is_floating_point,
119    torch.nn.functional._Reduction.get_enum,  # type: ignore[attr-defined]
120    torch.promote_types,
121    torch._C._get_privateuse1_backend_name,
122    torch.autograd._is_checkpoint_valid,
123]
124if torch.distributed.is_available():
125    constant_fold_functions.extend(
126        [
127            torch.distributed.is_initialized,
128            torch.distributed.get_rank,
129            torch.distributed.get_world_size,
130        ]
131    )
132# Convert to dict for O(1) access times
133constant_fold_functions = dict.fromkeys(constant_fold_functions)
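# Rough illustration: because torch.get_default_dtype is listed above, a call
# like torch.get_default_dtype() inside a compiled region is evaluated eagerly
# and baked into the trace as a Python constant rather than being placed in
# the FX graph (see can_constant_fold_through and call_function below).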
134
135
136tracing_state_functions = {
137    torch.jit.is_scripting: False,
138    torch.jit.is_tracing: False,
139    torch._C._get_tracing_state: None,
140    torch.fx._symbolic_trace.is_fx_tracing: False,
141    torch.onnx.is_in_onnx_export: False,
142    torch._dynamo.external_utils.is_compiling: True,
143    torch._utils.is_compiling: True,
144    torch.compiler.is_compiling: True,
145    torch.compiler.is_dynamo_compiling: True,
146    torch.nn.modules.activation._is_make_fx_tracing: False,
147}
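# These functions are folded to the fixed constants above while tracing, so
# e.g. under torch.compile, torch.compiler.is_compiling() becomes the constant
# True and torch.jit.is_scripting() becomes False
# (see handle_tracing_state_functions below).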
148
149bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
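# Names used below to detect torch-level "binary-style" ops (torch.add, ...)
# called with only SymInt/SymFloat arguments, which currently triggers an
# explicit graph break in call_function.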
150
151
152class BaseTorchVariable(VariableTracker):
153    """common base for all torch.* functions, classes, modules and other things"""
154
155    @classmethod
156    def create_with_source(cls, value, source):
157        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
158        return cls(value, source=source)
159
160    def __init__(self, value, **kwargs) -> None:
161        super().__init__(**kwargs)
162        self.value = value
163
164    def reconstruct(self, codegen):
165        try:
166            name = f"{self.value.__module__}.{self.value.__name__}"
167        except Exception:
168            name = f"torch_obj_{id(self.value)}"
169        unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
170        codegen.extend_output(
171            codegen.setup_globally_cached(unique_var_name, self.value)
172        )
173
174    def as_proxy(self):
175        return self.value
176
177    def as_python_constant(self):
178        return self.value
179
180    def call_hasattr(self, tx: "InstructionTranslator", name):
181        result = hasattr(self.value, name)
182        return variables.ConstantVariable.create(result)
183
184    def can_constant_fold_through(self):
185        if self.value in constant_fold_functions:
186            return True
187        return getattr(self.value, "__module__", None) == "math"
188
189
190class TorchCtxManagerClassVariable(BaseTorchVariable):
191    """Points to a context manager class in torch.* that dynamo has implementations"""
192
193    def __repr__(self) -> str:
194        return f"TorchCtxManagerClassVariable({self.value})"
195
196    @staticmethod
197    def is_matching_cls(value):
198        # Unwrap if it's a functools.lru_cache wrapper
199        value = unwrap_if_wrapper(value)
200        # We can't do isinstance(value, type) check because some ctx managers
201        # are implemented as a function decorated by contextlib.contextmanager,
202        # E.g., torch._functorch.vmap.vmap_increment_nesting.
203        return (
204            # Context manager type or function with @contextmanager is callable
205            callable(value)
206            and (
207                hashable(value)  # accesses value.__hash__()
208                and value in supported_ctx_manager_classes
209            )
210        )
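        # For example, torch.no_grad (a class) and
        # torch._functorch.vmap.vmap_increment_nesting (a @contextmanager
        # function) both match, since both appear in
        # supported_ctx_manager_classes above.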
211
212    def call_function(
213        self,
214        tx: "InstructionTranslator",
215        args: "List[VariableTracker]",
216        kwargs: "Dict[str, VariableTracker]",
217    ) -> "VariableTracker":
218        from . import (
219            DisabledSavedTensorsHooksVariable,
220            DualLevelContextManager,
221            FSDPParamGroupUseTrainingStateVariable,
222            GradIncrementNestingCtxManagerVariable,
223            GradInplaceRequiresGradCtxManagerVariable,
224            GradModeVariable,
225            InferenceModeVariable,
226            JvpIncrementNestingCtxManagerVariable,
227            SetFwdGradEnabledContextManager,
228            StreamVariable,
229            VmapIncrementNestingCtxManagerVariable,
230        )
231
232        if self.value is torch.no_grad:
233            if len(args) == 1 and isinstance(
234                args[0], variables.functions.BaseUserFunctionVariable
235            ):
236                ctx = GradModeVariable.create(tx, False)
237                return ctx.call_function(tx, args, kwargs)
238            else:
239                return GradModeVariable.create(tx, False)
240        elif self.value is torch.enable_grad:
241            if len(args) == 1 and isinstance(
242                args[0], variables.functions.BaseUserFunctionVariable
243            ):
244                ctx = GradModeVariable.create(tx, True)
245                return ctx.call_function(tx, args, kwargs)
246            return GradModeVariable.create(tx, True)
247        elif self.value is torch.set_grad_enabled and len(args) == 1:
248            return GradModeVariable.create(
249                tx, args[0].as_python_constant(), initialized=True
250            )
251        elif self.value is torch.inference_mode:
252            assert len(args) <= 1 and len(kwargs) == 0
253            inf_mode = args[0].as_python_constant() if len(args) == 1 else True
254            return InferenceModeVariable.create(tx, inf_mode)
255        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
256            from torch._dynamo.variables.builder import wrap_fx_proxy_cls
257
258            return wrap_fx_proxy_cls(
259                StreamVariable,
260                tx,
261                tx.output.create_proxy(
262                    "call_function",
263                    self.value,
264                    (),
265                    {},
266                ),
267            )
268        elif self.value in (
269            torch.amp.autocast_mode.autocast,
270            torch.cuda.amp.autocast,
271            torch.cpu.amp.autocast,
272        ):
273            return AutocastModeVariable.create(self.value, args, kwargs)
274        elif self.value in (
275            torch.profiler.profile,
276            torch.profiler.record_function,
277            torch.autograd.profiler.profile,
278            torch.autograd.profiler.record_function,
279        ):
280            warning_once(log, "Profiler function %s will be ignored", self.value)
281            return NullContextVariable()
282        elif self.value is torch._C.DisableTorchFunctionSubclass:
283            assert not (args or kwargs)
284            return TorchFunctionDisableVariable.create(tx)
285        elif self.value is torch._functorch.vmap.vmap_increment_nesting:
286            assert len(args) == 2
287            return VmapIncrementNestingCtxManagerVariable.create(
288                tx,
289                [guard_if_dyn(x) for x in args],
290            )
291        elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting:
292            assert len(args) == 0
293            return JvpIncrementNestingCtxManagerVariable.create(tx)
294        elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled:
295            assert len(args) == 1
296            return SetFwdGradEnabledContextManager.create(
297                tx,
298                [guard_if_dyn(x) for x in args],
299            )
300        elif self.value is torch.autograd.forward_ad.dual_level:
301            assert len(args) == 0
302            return DualLevelContextManager.create(tx)
303        elif self.value is torch._functorch.eager_transforms.grad_increment_nesting:
304            assert len(args) == 0
305            return GradIncrementNestingCtxManagerVariable.create(tx)
306        elif (
307            self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad
308        ):
309            assert len(args) == 1
310            return GradInplaceRequiresGradCtxManagerVariable.create(
311                tx,
312                [guard_if_dyn(x) for x in args],
313            )
314        elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
315            assert len(args) == 1
316            return DisabledSavedTensorsHooksVariable.create(
317                tx, args[0].as_python_constant()
318            )
319        elif (
320            _fsdp_param_group is not None
321            and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state
322        ):
323            assert len(args) == 2
324            return FSDPParamGroupUseTrainingStateVariable.create(
325                tx, args[0], args[1].as_python_constant()
326            )
327
328        return super().call_function(tx, args, kwargs)
329
330
331class TorchInGraphFunctionVariable(BaseTorchVariable):
332    """Points to a torch function/method that should be put in FX graph"""
333
334    def __repr__(self) -> str:
335        return f"TorchInGraphFunctionVariable({self.value})"
336
337    def get_function(self):
338        return self.value
339
340    @staticmethod
341    @functools.lru_cache(None)
342    def _get_handlers():
343        """Build a dict from function -> method to handle it so that we are O(1)
344        in terms of the number of function with special handling."""
345        handlers = {}
346
347        def register(*fns):
348            def _register(handler):
349                for fn in fns:
350                    assert fn not in handlers, fn
351                    handlers[fn] = handler
352                return handler
353
354            assert callable(fns[0])
355            return _register
356
357        from torch.backends.cuda import SDPAParams
358
359        from . import (
360            ConstantVariable,
361            DeterministicAlgorithmsVariable,
362            GradModeVariable,
363            StreamContextVariable,
364            SymNodeVariable,
365            TensorVariable,
366            UserDefinedObjectVariable,
367        )
368        from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls
369
370        @register(*tracing_state_functions)
371        def handle_tracing_state_functions(
372            self, tx: "InstructionTranslator", *args, **kwargs
373        ):
374            assert not args and not kwargs
375            # See: https://github.com/pytorch/pytorch/issues/110765
376            if self.value in (
377                torch._utils.is_compiling,
378                torch._dynamo.external_utils.is_compiling,
379                torch.compiler.is_compiling,
380                torch.compiler.is_dynamo_compiling,
381            ):
382                tx.mark_inconsistent_side_effects()
383            return ConstantVariable.create(tracing_state_functions[self.value])
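            # Rough example: inside a compiled region,
            # torch.compiler.is_dynamo_compiling() does not appear in the FX
            # graph; it is folded here to ConstantVariable(True).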
384
385        @register(torch.overrides.get_default_nowrap_functions.__wrapped__)
386        def handle_get_default_nowrap_functions(
387            self, tx: "InstructionTranslator", *args, **kwargs
388        ):
389            # [Note: __torch_function__] we return an empty set here because we
390            # restrict the set of functions we trace __torch_function__ on to
391            # functions outside of the actual set. Implementing this properly will
392            # require variable types that track and compare tensor getset descriptors
393            return SourcelessBuilder.create(
394                tx, torch.overrides.get_default_nowrap_functions()
395            )
396
397        @register(torch.ops.inductor.accumulate_grad_.default)
398        def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
399            return tx.inline_user_function_return(
400                SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs
401            )
402
403        @register(math.radians)
404        def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
405            if not check_unspec_or_constant_args(args, kwargs):
406                # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
407                return tx.inline_user_function_return(
408                    SourcelessBuilder.create(tx, polyfills.radians), args, kwargs
409                )
410
411        @register(torch.is_tensor, torch.overrides.is_tensor_like)
412        def handle_is_tensor(self, tx: "InstructionTranslator", arg):
413            if isinstance(arg, TensorVariable) or (
414                self.value is torch.overrides.is_tensor_like
415                and isinstance(arg, UserDefinedObjectVariable)
416                and hasattr(arg.value, "__torch_function__")
417            ):
418                return ConstantVariable.create(True)
419            else:
420                return ConstantVariable.create(False)
421
422        @register(
423            torch.is_floating_point,
424            torch.is_complex,
425        )
426        def handle_is_floating_point(self, tx: "InstructionTranslator", input):
427            input_arg = input
428            if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
429                if self.value is torch.is_floating_point:
430                    return ConstantVariable.create(input_arg.dtype.is_floating_point)
431                elif self.value is torch.is_complex:
432                    return ConstantVariable.create(input_arg.dtype.is_complex)
433                else:
434                    raise AssertionError(f"calling {self.value}")
435
436        @register(torch.numel)
437        def handle_numel(self, tx: "InstructionTranslator", input):
438            if isinstance(input, TensorVariable) and input.size is not None:
439                return ConstantVariable.create(product(input.size))
440            elif isinstance(input, TensorVariable):
441                # Workaround dynamic shapes issue
442                return input.call_method(tx, "numel", [], {})
443
444        @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
445        def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input):
446            assert isinstance(input, TensorVariable)
447            return input.call_method(tx, "size", [], {})
448
449        @register(
450            torch.nn.modules.utils._single,
451            torch.nn.modules.utils._pair,
452            torch.nn.modules.utils._triple,
453            torch.nn.modules.utils._quadruple,
454            torch.nn.modules.utils._ntuple,
455        )
456        def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs):
457            return self._call_ntuple(tx, args, kwargs)
458
459        @register(torch.is_grad_enabled)
460        def handle_is_grad_enabled(self, tx):
461            install_guard(GradModeVariable._guards_singleton)
462            return ConstantVariable.create(torch.is_grad_enabled())
463
464        @register(torch.use_deterministic_algorithms)
465        def handle_use_deterministic_algorithms(
466            self, tx: "InstructionTranslator", mode, warn_only=False
467        ):
468            if warn_only and warn_only.as_python_constant():
469                unimplemented("torch.use_deterministic_algorithms(warn_only=True)")
470            return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())
471
472        @register(torch.are_deterministic_algorithms_enabled)
473        def handle_are_deterministic_algorithms_enabled(self, tx):
474            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
475            return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())
476
477        @register(torch._C._is_torch_function_enabled)
478        def handle_is_torch_function_enabled(self, tx):
479            install_guard(TorchFunctionDisableVariable._guards_singleton)
480            return ConstantVariable.create(tx.output.torch_function_enabled)
481
482        @register(
483            torch.overrides.has_torch_function,
484            torch.overrides.has_torch_function_variadic,
485            torch.overrides.has_torch_function_unary,
486        )
487        def handle_has_torch_function(self, tx: "InstructionTranslator", *args):
488            elems = (
489                args[0].unpack_var_sequence(tx)
490                if len(args) == 1 and isinstance(args[0], TupleVariable)
491                else args
492            )
493            return ConstantVariable.create(
494                any(has_torch_function(x) for x in elems),
495            )
496
497        @register(
498            *dict.fromkeys(  # remove duplicates
499                device_interface.stream
500                for _, device_interface in get_registered_device_interfaces()
501            )
502        )
503        def handle_device_interface_stream(self, tx: "InstructionTranslator", stream):
504            return StreamContextVariable.create(tx, stream)
505
506        @register(torch.from_numpy)
507        def handle_from_numpy(self, tx: "InstructionTranslator", *args):
508            if not config.trace_numpy:
509                unimplemented("torch.from_numpy. config.trace_numpy is False")
510            if not np:
511                unimplemented("torch.from_numpy. NumPy is not available")
512            return wrap_fx_proxy_cls(
513                target_cls=TensorVariable,
514                tx=tx,
515                proxy=tx.output.create_proxy(
516                    "call_function",
517                    torch.as_tensor,
518                    *proxy_args_kwargs(args, {}),
519                ),
520                example_value=None,
521            )
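            # Note: the call is traced as torch.as_tensor rather than
            # torch.from_numpy itself; the original args are proxied unchanged.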
522
523        @register(torch.jit.annotate)
524        def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value):
525            return the_value
526
527        @register(torch.backends.cudnn.is_acceptable)
528        def handle_cudnn_is_acceptable(
529            self, tx: "InstructionTranslator", tensor, *extra
530        ):
531            # is_acceptable(tensor) returns true if
532            #   (a) tensor dtype/device are supported by cudnn
533            #   (b) cudnn is available
534            #   (c) some initialization has completed
535            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
536            assert not extra, "Expect 1 input to cudnn.is_acceptable"
537            assert isinstance(
538                tensor, TensorVariable
539            ), "Expect input to cudnn.is_acceptable to be a tensor"
540            tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device)
541            return ConstantVariable.create(
542                torch.backends.cudnn.is_acceptable(tensor_inp)
543            )
544
545        @register(torch.utils.hooks.BackwardHook)
546        def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs):
547            return variables.BackwardHookVariable.create(tx, *args, **kwargs)
548
549        @register(torch.nn.Parameter)
550        def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs):
551            return self.call_nn_parameter(tx, *args, **kwargs)
552
553        @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int)
554        def handle_sym_size(self_, tx, self, dim=None):
555            # we see this when retracing already traced code
556            if dim is not None:
557                return self.call_method(tx, "size", [dim], {})
558
559        @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int)
560        def handle_sym_stride(self_, tx, self, dim=None):
561            if dim is not None:
562                return self.call_method(tx, "stride", [dim], {})
563
564        @register(torch.addcdiv)
565        def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs):
566            if len(args) == 3 and "value" in kwargs and len(kwargs) == 1:
567                # decompose addcdiv into constituent ops; this prevents a graph break
568                # due to converting value to a scalar
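                # Sketch of the rewrite: torch.addcdiv(input, t1, t2, value=v)
                # is emitted as torch.add(input, torch.mul(torch.div(t1, t2), v)),
                # matching the three calls below.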
569                result = TorchInGraphFunctionVariable(torch.div).call_function(
570                    tx, [*args[1:]], {}
571                )
572                result = TorchInGraphFunctionVariable(torch.mul).call_function(
573                    tx, [result, kwargs["value"]], {}
574                )
575                return TorchInGraphFunctionVariable(torch.add).call_function(
576                    tx, [args[0], result], {}
577                )
578
579        @register(torch._foreach_lerp_)
580        def handle_inplace_foreach_lerp_scalar(
581            self, tx: "InstructionTranslator", *args, **kwargs
582        ):
583            if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
584                return tx.inline_user_function_return(
585                    SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
586                    args,
587                    kwargs,
588                )
589
590        @register(torch._foreach_pow)
591        def handle_foreach_pow_scalar(
592            self, tx: "InstructionTranslator", *args, **kwargs
593        ):
594            # In eager it's more performant to call item() from within the C op implementation;
595            # in compile, it's more performant to not graph break.
596            if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
597                return tx.inline_user_function_return(
598                    SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
599                    args,
600                    kwargs,
601                )
602
603        @register(torch._assert)
604        def handle_assert(self, tx: "InstructionTranslator", condition, message):
605            if (condition.is_python_constant() and condition.as_python_constant()) or (
606                isinstance(condition, variables.SymNodeVariable)
607                and condition.evaluate_expr()
608            ):
609                return ConstantVariable(None)
610
611        @register(SDPAParams)
612        def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs):
613            return wrap_fx_proxy(
614                tx,
615                proxy=tx.output.create_proxy(
616                    "call_function",
617                    torch._C._SDPAParams,
618                    *proxy_args_kwargs(args, kwargs),
619                ),
620                param_vars=args,
621            )
622
623        if DistributedVariable.is_available():
624            from torch.distributed.distributed_c10d import (
625                _get_group_size_by_name,
626                _get_group_tag,
627                _rank_not_in_group,
628                _resolve_group_name_by_ranks_and_tag,
629                get_process_group_ranks,
630            )
631            from torch.distributed.tensor import DTensor
632
633            @register(
634                _get_group_size_by_name,
635                _get_group_tag,
636                _rank_not_in_group,
637                get_process_group_ranks,
638                _resolve_group_name_by_ranks_and_tag,
639            )
640            def handle_constant_processgroup_functions(
641                self, tx: "InstructionTranslator", *args
642            ):
643                # because the input is a "ProcessGroupVariable", we'll be guarding on its
644                # ID_MATCH based on how it was constructed.
645
646                # We desugar it at trace-time into ranks by directly calling the
647                # util and baking the result into the trace
648                if len(args) == 1:
649                    # group or group name
650                    assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable))
651                elif len(args) == 2:
652                    # ranks + tag
653                    assert isinstance(args[0], ListVariable) and isinstance(
654                        args[1], ConstantVariable
655                    )
656                else:
657                    raise AssertionError(
658                        f"Invalid group value ({args}) for constant pg "
659                        f"function {self.value}"
660                    )
661                args_as_value = [arg.as_python_constant() for arg in args]
662                invocation_result = self.value(*args_as_value)
663
664                # Note - while we *could* cook up sources around invocations, like a FunctionSource,
665                # the space of invoking functions in the middle of the guard chain is very iffy. As such,
666                # guard propagation via options is the best we can do.
667                return SourcelessBuilder.create(tx, invocation_result)
668
669            @register(DTensor.from_local)
670            def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
671                # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
672                # and rewrite args to have only proxyable args, then insert call_function
673                args_as_value = [x.as_python_constant() for x in args[1:]]
674                kwargs_as_value = {
675                    k: v.as_python_constant()
676                    for k, v in kwargs.items()
677                    if k not in ["shape", "stride"]
678                }
679                kwargs_to_be_proxied = {
680                    k: kwargs[k] for k in ["shape", "stride"] if k in kwargs
681                }
682
683                def fn_with_prim_types(x, shape=None, stride=None):
684                    return self.value(
685                        x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride
686                    )
687
688                # attach the same function name for better debugging
689                fn_with_prim_types.__name__ = "prim " + self.value.__name__
690
691                return wrap_fx_proxy(
692                    tx=tx,
693                    proxy=tx.output.create_proxy(
694                        "call_function",
695                        fn_with_prim_types,
696                        *proxy_args_kwargs(
697                            [args[0]],
698                            kwargs_to_be_proxied,
699                        ),
700                    ),
701                )
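                # Rough sketch of the effect: for a call like
                #   DTensor.from_local(t, mesh, placements, shape=s, stride=st)
                # mesh/placements (and other kwargs) are captured as Python
                # constants inside fn_with_prim_types, while t and the
                # shape/stride kwargs remain proxied into the graph.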
702
703        @register(torch.nested.nested_tensor)
704        def handle_nested_tensor(
705            self,
706            tx: "InstructionTranslator",
707            tensor_list=None,
708            *args,
709            layout=None,
710            **kwargs,
711        ):
712            from .lists import BaseListVariable
713
714            if layout and layout.as_python_constant() == torch.strided:
715                unimplemented("torch.compile does not support strided NestedTensor")
716            if not isinstance(tensor_list, BaseListVariable):
717                unimplemented("nested_tensor with non-list input")
718
719        @register(torch.nn.functional.one_hot)
720        def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs):
721            if len(args) + len(kwargs) == 1 or (
722                len(args) == 2
723                and args[1].is_python_constant()
724                and args[1].as_python_constant() == -1
725            ):
726                unimplemented(
727                    "torch.nn.functional.one_hot with data-dependent output shape"
728                )
729
730        @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
731        def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr):
732            if isinstance(expr, SymNodeVariable):
733                # TODO: this probably should be folded somewhere else but I'm not sure where
734                # TODO: some of the other symbolic_shapes special tools can also get this treatment too
735                return variables.ConstantVariable.create(
736                    torch.fx.experimental.symbolic_shapes.guard_size_oblivious(
737                        expr.sym_num
738                    )
739                )
740            elif isinstance(expr, ConstantVariable):
741                return expr
742
743        @register(torch._C._autograd._unsafe_set_version_counter)
744        def handle_unsafe_set_version_counter(
745            self, tx: "InstructionTranslator", *args, **kwargs
746        ):
747            from ..tensor_version_op import _unsafe_set_version_counter
748
749            return TorchInGraphFunctionVariable(
750                _unsafe_set_version_counter
751            ).call_function(tx, [*args], kwargs)
752
753        @register(torch.tensor)
754        def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs):
755            def check_any_unspec(x):
756                # NB: This includes UnspecializedPythonVariable
757                if isinstance(x, (TensorVariable, SymNodeVariable)):
758                    return True
759                elif isinstance(x, (ListVariable, TupleVariable)):
760                    return any(check_any_unspec(y) for y in x.items)
761                # TODO: there may be other recursive structures you need to
762                # check
763                else:
764                    return False
765
766            data_arg = None
767            if args:
768                data_arg = args[0]
769            elif "data" in kwargs:
770                data_arg = kwargs["data"]
771
772            # NB: OK to pass torch.tensor(tensor), this will trace fine
773            if not isinstance(data_arg, TensorVariable) and check_any_unspec(data_arg):
774                # This is slower and less canonical, so only use it if we
775                # have to
776                return TorchInGraphFunctionVariable(torch._refs.tensor).call_function(
777                    tx, [*args], kwargs
778                )
779
780        @register(torch._C._pop_torch_function_stack)
781        def handle_pop_torch_function(
782            self, tx: "InstructionTranslator", *args, **kwargs
783        ):
784            assert not args and not kwargs
785            if not tx.symbolic_torch_function_mode_stack:
786                raise unimplemented("Popping from an empty torch function mode stack")
787            TorchFunctionModeStackVariable.register_mutation(tx)
788            return tx.symbolic_torch_function_mode_stack.pop()
789
790        @register(torch._C._push_on_torch_function_stack)
791        def handle_push_torch_function(
792            self, tx: "InstructionTranslator", *args, **kwargs
793        ):
794            assert len(args) == 1 and not kwargs
795            TorchFunctionModeStackVariable.register_mutation(tx)
796            tx.symbolic_torch_function_mode_stack.append(args[0])
797            return ConstantVariable.create(None)
798
799        @register(torch._C._len_torch_function_stack)
800        def handle_len_torch_function(
801            self, tx: "InstructionTranslator", *args, **kwargs
802        ):
803            assert not args and not kwargs
804            return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
805
806        @register(torch.set_default_device)
807        def handle_set_default_device(
808            self, tx: "InstructionTranslator", *args, **kwargs
809        ):
810            # Today this is inserted in the graph; once TF mode
811            # handling is complete, we can trace the device context
812            # like any other TF mode and remove this special handling.
813            # Insert the TF mode representing the device context at
814            # the bottom of the stack to match the eager semantics.
815            # Running the graph will ensure that the DeviceContext mode is
816            # at the correct position in the stack.
817            TorchFunctionModeStackVariable.register_mutation(tx)
818            if args[0].is_python_constant() and args[0].as_python_constant() is None:
819                TorchFunctionModeStackVariable.clear_default_device(tx)
820            else:
821                TorchFunctionModeStackVariable.register_device_context_insertion(tx)
822
823            return None
824
825        return handlers
826
827    def call_function(
828        self,
829        tx: "InstructionTranslator",
830        args: "List[VariableTracker]",
831        kwargs: "Dict[str, VariableTracker]",
832    ) -> "VariableTracker":
833        from . import ConstantVariable, SymNodeVariable, TensorVariable
834        from .builder import wrap_fx_proxy
835
836        if self.can_constant_fold_through() and check_unspec_or_constant_args(
837            args, kwargs
838        ):
839            # constant fold
840            return ConstantVariable.create(
841                self.as_python_constant()(
842                    *[x.as_python_constant() for x in args],
843                    **{k: v.as_python_constant() for k, v in kwargs.items()},
844                ),
845            )
846
847        special_handler = self._get_handlers().get(self.value)
848        if special_handler:
849            result = special_handler(self, tx, *args, **kwargs)
850            if result:
851                return result
852
853        if can_dispatch_torch_function(tx, args, kwargs):
854            return dispatch_torch_function(tx, self, args, kwargs)
855        else:
856            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
857
858            all_ints_or_floats = all(
859                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
860                for x in args
861            )
862            if (
863                getattr(self.value, "__module__", "") == "torch"
864                and self.value.__name__ in bin_ops
865                and any_symints_or_symfloats
866                and all_ints_or_floats
867            ):
868                msg = f"""\
869Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
870To support this behavior, we need to allow const-propping tensors that store symint data.
871For now, dynamo will explicitly graph break when it encounters user code with this behavior.
872"""
873                log.warning(msg)
874                unimplemented(msg)
875
876            # TODO(voz): Replace w/ dynamic shape rewrite table.
877            # Ideally, we would be able to do this at ctor time, but alas we need a combination
878            # of value + args to determine this.
879            fn_ = self.value
880            if any_symints_or_symfloats:
881                torch_sym_op = f"_sym_{self.value.__name__}"
882                if getattr(self.value, "__module__", None) == "math" and hasattr(
883                    torch, torch_sym_op
884                ):
885                    fn_ = getattr(torch, torch_sym_op)
886
887            fake_out_shape = None
888            if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
889                # Calling fake tensor propagation can mutate the out= tensor in
890                # tx.output.tracked_fakes. tracked_fakes are used to apply
891                # symbolic_shape guards. Mutating them destroys the information
892                # prior to tracing, which is essential for creating the right
893                # guards. So save the shape now, and check later if it has
894                # changed. If it has, graph break.
895                fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
896
897            tensor_variable = wrap_fx_proxy(
898                tx=tx,
899                proxy=tx.output.create_proxy(
900                    "call_function",
901                    fn_,
902                    *proxy_args_kwargs(args, kwargs),
903                ),
904            )
905
906            if (
907                isinstance(tensor_variable, TensorVariable)
908                and "requires_grad" in kwargs
909                and kwargs["requires_grad"].as_python_constant()
910            ):
911                unimplemented(
912                    """factory functions that return tensors that require grad are not supported.
913Either create the tensor outside the compiled region, or do not set requires_grad on the tensor"""
914                )
915
916            if "out" in kwargs and not (
917                isinstance(kwargs["out"], variables.ConstantVariable)
918                and kwargs["out"].as_python_constant() is None
919            ):
920                # out variants of torch operators like torch.sort and
921                # torch.sigmoid mutate the tensors in the out field. Track such
922                # tensors and rewrite the symbolic locals.
923                if isinstance(tensor_variable, TupleVariable):
924                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
925                    output_tensor_names = [
926                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
927                    ]
928                    for idx, name in enumerate(output_tensor_names):
929                        if name in tx.symbolic_locals:
930                            tx.symbolic_locals[name] = tensor_variable.items[idx]
931                    for out_tensor, result_tensor in zip(
932                        kwargs["out"].items, tensor_variable.items
933                    ):
934                        if (
935                            out_tensor.source
936                            and out_tensor in tx.output.graphargs
937                            and isinstance(out_tensor, variables.TensorVariable)
938                            and isinstance(result_tensor, variables.TensorVariable)
939                            and out_tensor.size != result_tensor.size
940                        ):
941                            # It's hard to get out variants with resizing on graph inputs to work
942                            # properly across dynamo/aot/inductor; just fall back.
943                            unimplemented("out variants with resizing on graph inputs")
944                elif isinstance(tensor_variable, TensorVariable):
945                    assert isinstance(kwargs["out"], TensorVariable)
946                    assert "example_value" in kwargs["out"].proxy.node.meta
947                    fake_tensor = tensor_variable.proxy.node.meta["example_value"]
948                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
949                    if (
950                        kwargs["out"].source
951                        and kwargs["out"] in tx.output.graphargs
952                        and fake_out_shape != fake_tensor.shape
953                    ):
954                        # It's hard to get out variants with resizing on graph inputs work
955                        # properly across dynamo/aot/inductor, just fall back.
956                        unimplemented("out variants with resizing on graph inputs")
957                    if not torch._prims_common.is_contiguous(fake_out):
958                        # It's difficult to handle strides correctly in functionalization
959                        # when calling an out= op with a non-contiguous out argument
960                        unimplemented(
961                            "out= op was called where output tensor was non-contiguous"
962                        )
963                    name = tx.find_symbolic_locals_name(kwargs["out"])
964                    if name in tx.symbolic_locals:
965                        tx.symbolic_locals[name] = tensor_variable
966                elif (
967                    isinstance(tensor_variable, ConstantVariable)
968                    and tensor_variable.value is None
969                ):
970                    # Handle out-variant custom ops that return None.
971                    if isinstance(kwargs["out"], TensorVariable):
972                        assert "example_value" in kwargs["out"].proxy.node.meta
973                        fake_out = kwargs["out"].proxy.node.meta["example_value"]
974                        if not torch._prims_common.is_contiguous(fake_out):
975                            # It's difficult to handle strides correctly in functionalization
976                            # when calling an out= op with a non-contiguous out argument
977                            unimplemented(
978                                "out= op was called where output tensor was non-contiguous"
979                            )
980                    elif isinstance(kwargs["out"], ListVariable):
981                        for idx, x in enumerate(kwargs["out"].items):
982                            assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
983                            fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
984                            if not torch._prims_common.is_contiguous(fake_out):
985                                # It's difficult to handle strides correctly in functionalization
986                                # when calling an out= op with a non-contiguous out argument
987                                unimplemented(
988                                    "out= op was called where some of the output tensors were non-contiguous"
989                                )
990                else:
991                    unimplemented(f"out variant of {type(kwargs['out'])}")
992
993            return tensor_variable
994
995    def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
996        """inline behavior of torch.nn.modules.utils._ntuple"""
997        if self.value is torch.nn.modules.utils._ntuple:
998            count = args[0].as_python_constant()
999        else:
1000            count = self.value.__closure__[0].cell_contents
1001        assert isinstance(count, int)
1002        assert not kwargs
1003
1004        def handle_ntuple(value):
1005            if value.has_unpack_var_sequence(tx):
1006                return variables.TupleVariable(
1007                    list(value.unpack_var_sequence(tx)),
1008                )
1009            elif value.is_python_constant():
1010                # constant prop through it
1011                return variables.ConstantVariable.create(
1012                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
1013                )
1014            else:
1015                unimplemented(f"torch.nn.modules.utils._ntuple({value})")
1016
1017        if self.value is torch.nn.modules.utils._ntuple:
1018            return variables.LambdaVariable(handle_ntuple)
1019        else:
1020            return handle_ntuple(args[0])
1021
1022    @classmethod
1023    def call_nn_parameter(cls, tx, data=None, requires_grad=True):
1024        """A call to torch.nn.Parameter() gets lifted to before the graph"""
1025        if tx.export:
1026            unimplemented("nn parameter construction not supported with export")
1027
1028        if isinstance(requires_grad, variables.VariableTracker):
1029            try:
1030                requires_grad = requires_grad.as_python_constant()
1031            except NotImplementedError:
1032                unimplemented("Parameter(requires_grad=...) not constant")
1033
1034        if not isinstance(data, variables.TensorVariable):
1035            unimplemented(f"Parameter(data={data}) not implemented")
1036
1037        # this results in cleaner graphs, but only works for inputs
1038        if data.source:
1039            return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
1040
1041        if is_traceable_wrapper_subclass_type(data.class_type):
1042            unimplemented("Parameter constructor with tensor subclass NYI")
1043
1044        if not can_convert_to_tracable_parameter():
1045            unimplemented("Workaround for issues with nn_parameter construction")
1046
1047        try:
1048            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
1049            dtype = data.var_getattr(tx, "dtype").as_python_constant()
1050            device = data.var_getattr(tx, "device").as_python_constant()
1051        except NotImplementedError as e:
1052            unimplemented(f"Parameter not python_constant: {e}")
1053
1054        placeholder = tx.output.synthetic_graph_input(
1055            new_parameter_placeholder, [shape, dtype, device, requires_grad]
1056        )
1057        if data.requires_grad:
1058            data = data.call_method(tx, "detach", [], {})
1059
1060        from .builder import wrap_fx_proxy
1061
1062        result = wrap_fx_proxy(
1063            tx,
1064            tx.output.create_proxy(
1065                "call_function",
1066                tracable_create_parameter,
1067                (data.as_proxy(), placeholder.as_proxy()),
1068                {},
1069            ),
1070        )
1071        assert isinstance(result, variables.TensorVariable)
1072        result.class_type = torch.nn.Parameter
1073
1074        # TODO(jansel/bdhirsh) - There is some issue with
1075        # tracable_create_parameter. It does not seem to use the right
1076        # grad_enabled. Since this is a parameter, we can just override the
1077        # has_grad_fn field to False to work around the issue.
1078        result.has_grad_fn = False
1079
1080        # In reconstruct() we should use the original parameter.  The one returned by the graph will be an alias.
1081        result.source = placeholder.source
1082
1083        # TODO(jansel): if the new param falls out of scope, currently it won't get freed until
1084        # the end of the graph.  We should fix this.
1085        return result
1086
1087    @staticmethod
1088    def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
1089        # Alternate version if we have a .source
1090        from .builder import VariableBuilder
1091
1092        varname = tx.output.new_var()
1093
1094        # construct the nn.Parameter before the graph and save it to varname
1095        cg = PyCodegen(tx)
1096        cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter"))
1097        cg(data.source)
1098        cg(variables.ConstantVariable(requires_grad))
1099        cg.call_function(2, False)
1100        cg.store(varname)
1101        tx.output.pregraph_bytecode.extend(cg.get_instructions())
1102
1103        data_node = data.as_proxy().node
1104        if data_node.op not in ("placeholder", "get_attr"):
1105            unimplemented(
1106                "Unexpected type of data placeholder op for parameter construction"
1107            )
1108
1109        # add the newly constructed nn.Parameter as a graph input
1110        source = SyntheticLocalSource(varname)
1111        example_value = torch.nn.Parameter(
1112            tx.output.example_value_from_input_node(data.as_proxy().node)
1113        )
1114        result = VariableBuilder(tx, source)(example_value)
1115        # No need to guard on this since we already guarded on `data`.
1116        # These guards would fail since varname doesn't exist until after the function starts
1117        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
1118            source
1119        )
1120        return result
1121