# mypy: ignore-errors

import operator
from typing import Dict, List, TYPE_CHECKING

import torch
from torch._dynamo.source import GetItemSource

from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder, install_guard
from ..utils import common_constant_types, istype, np
from .base import typestr, VariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


_type_to_assert_reason = {
    # NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
    # A locally created set should always become a SetVariable, as the items in the set will already be either
    # sourced from somewhere else or unsourced. An input set would imply sources derived from the set's contents.
    # For example, an input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set,
    # arbitrary access is not possible. This is a solvable problem, but one we have not taken on yet. As such, input
    # sets are not allowed to become SetVariables. The solution here is to create a ConstantSetVariable that is more
    # like a ConstantVariable. As this does not exist, we cannot add sets to this invariant.
    list: "List types must use ListVariable.",
    dict: "Dict types must use ConstDictVariable.",
    torch.Tensor: "Tensor types must use TensorVariable.",
    torch.SymInt: "SymInts must use SymNodeVariable. "
    "If the underlying value is static, we will create a ConstantVariable and specialize.",
    torch.SymFloat: "SymFloats must use SymNodeVariable.",
}
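# Rough illustration of the invariant above and of ConstantVariable.create
# routing (a hypothetical REPL-style sketch, not executed at import time;
# no source or tx is needed for plain literals):
#
#   ConstantVariable.create(3)              # ConstantVariable(int: 3)
#   ConstantVariable.create((1, 2))         # TupleVariable of two ConstantVariables
#   ConstantVariable.create({1, 2})         # SetVariable of two ConstantVariables
#   ConstantVariable.create(torch.ones(1))  # AssertionError: "Tensor types must use TensorVariable."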


class ConstantVariable(VariableTracker):
    @staticmethod
    def create(value, **kwargs) -> VariableTracker:
        source = kwargs.get("source", None)
        is_literal = ConstantVariable.is_literal(value)
        if not is_literal:
            for disallowed_type, reason in _type_to_assert_reason.items():
                assert not isinstance(value, disallowed_type), reason

        # Routing for set, frozenset, list and tuple literals.
        if is_literal and isinstance(value, (set, frozenset)):
            items = [ConstantVariable.create(x) for x in value]
            return variables.SetVariable(items, **kwargs)
        elif is_literal and isinstance(value, (list, tuple)):
            items = []
            for i, x in enumerate(value):
                item_source = GetItemSource(source, i) if source else None
                if item_source:
                    install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
                items.append(
                    ConstantVariable.create(
                        x,
                        source=item_source,
                    )
                )
            return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)

        return ConstantVariable(value, **kwargs)
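    # Sketch of the sourced path above (hypothetical input): wrapping an input
    # tuple that carries a source such as LocalSource("t") yields a
    # TupleVariable whose elements are ConstantVariables sourced at t[0],
    # t[1], ..., each protected by a CONSTANT_MATCH guard installed through
    # GetItemSource.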

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        if not ConstantVariable.is_literal(value):
            for disallowed_type, reason in _type_to_assert_reason.items():
                assert not isinstance(value, disallowed_type), reason

        assert not isinstance(
            value, (list, tuple)
        ), "ConstantVariable(list) is banned - please create a ListVariable(items)"
        if np is not None and isinstance(value, np.number):
            self.value = value.item()
        else:
            self.value = value

    def as_proxy(self):
        return self.value

    def __str__(self) -> str:
        return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"

    def as_python_constant(self):
        return self.value

    def is_python_constant(self):
        return True

    @property
    def items(self):
        """
        Needed when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        return ConstantVariable.create(
            self.value[arg.as_python_constant()],
        )
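    # getitem_const (above), sketched with hypothetical values: indexing is
    # constant-folded into a fresh ConstantVariable, e.g.
    #   ConstantVariable.create("abc").getitem_const(tx, ConstantVariable.create(1))
    #       -> ConstantVariable(str: 'b')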

    @staticmethod
    def is_literal(obj):
        if type(obj) in common_constant_types:
            return True
        # The structured literals matched here get routed to variables.BaseListVariable (or SetVariable) by create()
        if type(obj) in (list, tuple, set, frozenset, torch.Size):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False
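    # is_literal (above), sketched with hypothetical values:
    #   ConstantVariable.is_literal(1.5)              -> True   (a common constant type)
    #   ConstantVariable.is_literal((1, "a", None))   -> True   (every element is a literal)
    #   ConstantVariable.is_literal([torch.ones(1)])  -> False  (a tensor element is not a literal)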

    def unpack_var_sequence(self, tx):
        try:
            return [ConstantVariable.create(x) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e
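    # unpack_var_sequence (above), sketch: an iterable constant unpacks into
    # per-element ConstantVariables, while a non-iterable one surfaces
    # NotImplementedError so callers can fall back, e.g.
    #   ConstantVariable.create("ab").unpack_var_sequence(tx)  -> [ConstantVariable(str: 'a'), ConstantVariable(str: 'b')]
    #   ConstantVariable.create(3).unpack_var_sequence(tx)     -> raises NotImplementedError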

    def const_getattr(self, tx: "InstructionTranslator", name):
        if isinstance(self.value, type):
            raise UserError(
                UserErrorType.ANTI_PATTERN,
                "Can't access members of type(obj) for a generated custom object. "
                "Please use __class__ instead",
                case_name="type_reflection_method",
            )
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .tensor import SymNodeVariable

        if name == "format" and istype(self.value, str):
            return variables.BuiltinVariable(str.format).call_function(
                tx, [self, *args], kwargs
            )
        elif name == "join" and istype(self.value, str):
            assert len(args) == 1 and len(kwargs) == 0
            arg_unpacked = args[0].force_unpack_var_sequence(tx)
            try:
                arg_const = [x.as_python_constant() for x in arg_unpacked]
                return ConstantVariable.create(self.value.join(arg_const))
            except NotImplementedError:
                return super().call_method(tx, name, args, kwargs)

        if any(isinstance(x, SymNodeVariable) for x in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
                tx, name, args, kwargs
            )

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            return super().call_method(tx, name, args, kwargs)

        if isinstance(self.value, str) and name in str.__dict__.keys():
            method = getattr(self.value, name)
            return ConstantVariable.create(method(*const_args, **const_kwargs))
        elif isinstance(self.value, (float, int)):
            if not (args or kwargs):
                return ConstantVariable.create(getattr(self.value, name)())
            if (
                hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            ):
                add_target = const_args[0]
                op = getattr(operator, name)
                if isinstance(
                    add_target, (torch.SymBool, torch.SymFloat, torch.SymInt)
                ):
                    # An operator applied to a non-sym and a sym value produces a sym
                    proxy = tx.output.create_proxy(
                        "call_function", op, (self.value, add_target), {}
                    )
                    return SymNodeVariable.create(tx, proxy, add_target)
                else:
                    return ConstantVariable.create(op(self.value, add_target))
        elif isinstance(self.value, bytes) and name == "decode":
            method = getattr(self.value, name)
            return ConstantVariable.create(method(*const_args, **const_kwargs))

        if name == "__len__" and not (args or kwargs):
            return ConstantVariable.create(len(self.value))
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable.create(result)

        unimplemented(f"const method call {typestr(self.value)}.{name}")
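    # call_method (above), sketched with hypothetical arguments: when every
    # argument folds to a Python constant the call is evaluated eagerly and
    # rewrapped, while a SymNodeVariable argument (some_symint_vt below is a
    # placeholder) promotes the whole operation to a symbolic one:
    #   ConstantVariable.create("ab").call_method(tx, "upper", [], {})
    #       -> ConstantVariable(str: 'AB')
    #   ConstantVariable.create(2).call_method(tx, "__add__", [ConstantVariable.create(3)], {})
    #       -> ConstantVariable(int: 5)
    #   ConstantVariable.create(2).call_method(tx, "__add__", [some_symint_vt], {})
    #       -> handled by SymNodeVariable, producing a symbolic result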

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)


class EnumVariable(VariableTracker):
    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def create(cls, cls_type, value_vt, options):
        if isinstance(value_vt, variables.ConstantVariable):
            for member in list(cls_type):
                if member.value == value_vt.as_python_constant():
                    return cls(member, **options)
        unimplemented("Enum variable is constructed with non-constant values")
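    # create (above), sketch assuming a hypothetical `class Color(enum.Enum): RED = 1`:
    #   EnumVariable.create(Color, ConstantVariable.create(1), {})
    # scans list(Color), matches Color.RED by value and wraps it in an
    # EnumVariable; a non-constant value_vt falls through to unimplemented().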

    def as_proxy(self):
        return self.value

    def __str__(self) -> str:
        return f"EnumVariable({type(self.value)})"

    def as_python_constant(self):
        return self.value

    def const_getattr(self, tx: "InstructionTranslator", name):
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError
        return member
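    # const_getattr (above), sketch with the hypothetical Color enum from the
    # note in create(): data attributes are constant-folded, callables raise
    # NotImplementedError so the caller can fall back, e.g.
    #   EnumVariable(Color.RED).const_getattr(tx, "value")  -> 1
    #   EnumVariable(Color.RED).const_getattr(tx, "name")   -> 'RED'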