# mypy: ignore-errors

import operator
from typing import Dict, List, TYPE_CHECKING

import torch
from torch._dynamo.source import GetItemSource

from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder, install_guard
from ..utils import common_constant_types, istype, np
from .base import typestr, VariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# Maps a type that must never be wrapped in a ConstantVariable (unless the
# value is a literal, see ConstantVariable.is_literal) to the assertion message
# raised when such a value is seen.
_type_to_assert_reason = {
    # NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
    # A locally created set should always become a SetVariable, as the items in the set will already either be sourced
    # from somewhere else, or unsourced. An input set would imply sources derived from set contents. For example, an
    # input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set, arbitrary access is
    # not possible. This is a solvable problem, but one we have not taken on yet. As such, input sets are not allowed to
    # become SetVariables. The solution here is to create a ConstantSetVariable that is more like a ConstantVariable.
    # As this does not exist, we cannot add sets to this invariant.
    list: "List types must use ListVariable.",
    dict: "Dict types must use ConstDictVariable.",
    torch.Tensor: "Tensor types must use TensorVariable.",
    torch.SymInt: "SymInts must use SymNodeVariable. "
    "If the underlying value is static, we will create a ConstantVariable and specialize.",
    torch.SymFloat: "SymFloats must use SymNodeVariable",
}


class ConstantVariable(VariableTracker):
    """VariableTracker wrapping a Python constant value.

    Use `ConstantVariable.create` rather than the constructor: it routes
    literal sets, lists and tuples to the matching container variable.
    """

    @staticmethod
    def create(value, **kwargs) -> VariableTracker:
        """Build the VariableTracker for the constant `value`.

        Literal sets/frozensets become a `SetVariable`; literal lists/tuples
        become the `BaseListVariable` subclass chosen by
        `BaseListVariable.cls_for`; anything else becomes a `ConstantVariable`.

        Args:
            value: the Python value to wrap.
            **kwargs: forwarded to the constructed variable. A `source` entry,
                when present, is also used to derive per-item sources (and
                CONSTANT_MATCH guards) for list/tuple elements.

        Raises:
            AssertionError: if `value` is not a literal and is an instance of a
                type listed in `_type_to_assert_reason`.
        """
        source = kwargs.get("source", None)
        is_literal = ConstantVariable.is_literal(value)
        if not is_literal:
            for disallowed_type, reason in _type_to_assert_reason.items():
                assert not isinstance(value, disallowed_type), reason

        # Routing for set, list and tuple literals.
        if is_literal and isinstance(value, (set, frozenset)):
            # Set elements get no item source: only `value` itself carries
            # whatever `source` was passed in kwargs.
            items = [ConstantVariable.create(x) for x in value]
            return variables.SetVariable(items, **kwargs)
        elif is_literal and isinstance(value, (list, tuple)):
            items = []
            for i, x in enumerate(value):
                item_source = GetItemSource(source, i) if source else None
                if item_source:
                    install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
                items.append(
                    ConstantVariable.create(
                        x,
                        source=item_source,
                    )
                )
            return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)

        return ConstantVariable(value, **kwargs)

    def __init__(self, value, **kwargs) -> None:
        """Store `value`, unwrapping numpy scalars to plain Python scalars.

        Raises:
            AssertionError: if `value` is a list/tuple, or is a non-literal
                instance of a type listed in `_type_to_assert_reason`.
        """
        super().__init__(**kwargs)
        if not ConstantVariable.is_literal(value):
            for disallowed_type, reason in _type_to_assert_reason.items():
                assert not isinstance(value, disallowed_type), reason

        assert not isinstance(
            value, (list, tuple)
        ), "ConstantVariable(list) is banned - please create a ListVariable(items)"
        if np is not None and isinstance(value, np.number):
            # `np` may be None (checked above); numpy scalars are stored as
            # their Python equivalent.
            self.value = value.item()
        else:
            self.value = value

    def as_proxy(self):
        """Return the wrapped constant itself."""
        return self.value

    def __str__(self) -> str:
        return f"ConstantVariable({type(self.value).__name__}: {self.value!r})"

    def as_python_constant(self):
        """Return the wrapped constant."""
        return self.value

    def is_python_constant(self):
        """A ConstantVariable is always a Python constant."""
        return True

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index the wrapped value with the constant held by `arg` and wrap the result."""
        return ConstantVariable.create(
            self.value[arg.as_python_constant()],
        )

    @staticmethod
    def is_literal(obj):
        """Return True if `obj` can be treated as a constant literal.

        That is: its exact type is in `common_constant_types`, or it is
        exactly a list/tuple/set/frozenset/torch.Size whose elements are all
        literals (checked recursively). Subclasses do not qualify because the
        check uses `type(obj)`, not `isinstance`.
        """
        if type(obj) in common_constant_types:
            return True
        # The structure within is_literal get routed to variables.BaseListVariable
        if type(obj) in (list, tuple, set, frozenset, torch.Size):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False

    def unpack_var_sequence(self, tx):
        """Wrap each element of the constant in its own variable.

        Raises:
            NotImplementedError: if iterating/wrapping raises TypeError
                (e.g. the constant is not iterable).
        """
        try:
            return [ConstantVariable.create(x) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e

    def const_getattr(self, tx: "InstructionTranslator", name):
        """Return the non-callable attribute `name` of the wrapped constant.

        Raises:
            UserError: if the wrapped value is itself a type.
            NotImplementedError: if the attribute is callable.
        """
        if isinstance(self.value, type):
            raise UserError(
                UserErrorType.ANTI_PATTERN,
                "Can't access members of type(obj) for a generated custom object. "
                "Please use __class__ instead",
                case_name="type_reflection_method",
            )
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Evaluate method `name` on the wrapped constant where possible.

        str.format and str.join are special-cased; calls with a
        SymNodeVariable argument are delegated to a SymNodeVariable built from
        this constant. With all-constant arguments, str methods, bytes.decode,
        no-argument and single-argument `operator`-named int/float methods,
        `__len__` and `__contains__` are evaluated directly. Non-constant
        arguments fall back to the base class; any other call ends in
        `unimplemented`.
        """
        from .tensor import SymNodeVariable

        if name == "format" and istype(self.value, str):
            return variables.BuiltinVariable(str.format).call_function(
                tx, [self, *args], kwargs
            )
        elif name == "join" and istype(self.value, str):
            assert len(args) == 1 and len(kwargs) == 0
            arg_unpacked = args[0].force_unpack_var_sequence(tx)
            try:
                arg_const = [x.as_python_constant() for x in arg_unpacked]
                return ConstantVariable.create(self.value.join(arg_const))
            except NotImplementedError:
                return super().call_method(tx, name, args, kwargs)

        if any(isinstance(x, SymNodeVariable) for x in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
                tx, name, args, kwargs
            )

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            # Some argument is not a constant; defer to the generic handling.
            return super().call_method(tx, name, args, kwargs)

        if isinstance(self.value, str) and name in str.__dict__:
            method = getattr(self.value, name)
            return ConstantVariable.create(method(*const_args, **const_kwargs))
        elif isinstance(self.value, (float, int)):
            if not (args or kwargs):
                return ConstantVariable.create(getattr(self.value, name)())
            if (
                hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            ):
                add_target = const_args[0]
                op = getattr(operator, name)
                if isinstance(
                    add_target, (torch.SymBool, torch.SymFloat, torch.SymInt)
                ):
                    # Addition between a non sym and sym makes a sym
                    proxy = tx.output.create_proxy(
                        "call_function", op, (self.value, add_target), {}
                    )
                    return SymNodeVariable.create(tx, proxy, add_target)
                else:
                    return ConstantVariable.create(op(self.value, add_target))
        elif isinstance(self.value, bytes) and name == "decode":
            method = getattr(self.value, name)
            return ConstantVariable.create(method(*const_args, **const_kwargs))

        if name == "__len__" and not (args or kwargs):
            return ConstantVariable.create(len(self.value))
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable.create(result)

        unimplemented(f"const method call {typestr(self.value)}.{name}")

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Return a constant bool: whether the wrapped value has attribute `name`."""
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)


class EnumVariable(VariableTracker):
    """VariableTracker wrapping a single enum member."""

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def create(cls, cls_type, value_vt, options):
        """Build an EnumVariable for the member of `cls_type` whose value equals `value_vt`.

        Args:
            cls_type: iterable of members exposing `.value` (an Enum class).
            value_vt: variable holding the member value to look up.
            options: keyword arguments forwarded to the constructor.

        Calls `unimplemented` when `value_vt` is not a ConstantVariable or no
        member matches.
        """
        if isinstance(value_vt, variables.ConstantVariable):
            wanted = value_vt.as_python_constant()
            for member in cls_type:
                if member.value == wanted:
                    return cls(member, **options)
        unimplemented("Enum variable is constructed with non constant values")

    def as_proxy(self):
        """Return the wrapped enum member itself."""
        return self.value

    def __str__(self) -> str:
        return f"EnumVariable({type(self.value)})"

    def as_python_constant(self):
        """Return the wrapped enum member."""
        return self.value

    def const_getattr(self, tx: "InstructionTranslator", name):
        """Return the non-callable attribute `name` of the wrapped member.

        Raises:
            NotImplementedError: if the attribute is callable.
        """
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError
        return member