xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/base.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4from enum import Enum
5from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
6
7from .. import variables
8from ..current_scope_id import current_scope_id
9from ..exc import unimplemented
10from ..source import AttrSource, Source
11from ..utils import istype
12
13
14if TYPE_CHECKING:
15    from torch._dynamo.symbolic_convert import InstructionTranslator
16
17
class MutableLocalSource(Enum):
    """
    Origin of the object behind a ``VariableTracker.mutable_local``.

    Members:
        Existing: the object already existed and Dynamo started tracking it
            while introspecting.
        Local: the object was newly created during Dynamo introspection.
    """

    Existing = 0  # pre-existing object picked up by Dynamo
    Local = 1  # object created while Dynamo was introspecting
27
28
class MutableLocalBase:
    """
    Common base for the marker stored in ``VariableTracker.mutable_local``.

    Each instance records in ``self.scope`` the scope id in which the tracked
    object originated, so later mutations can be checked against the scope
    that is currently being traced.
    """

    def __init__(self, typ: MutableLocalSource) -> None:
        # Why a scope is needed: while tracing a HigherOrderOperator, objects
        # built outside the HOP must not be mutated from inside it. In
        #
        #   def f(x):
        #       a = 1
        #       def g(x):
        #           nonlocal a
        #           a = 2
        #           return x
        #       return wrap(g, x) + a
        #
        # the assignment to `a` inside `g` is unsafe because `a` belongs to a
        # different scope.
        #
        # Meaning of self.scope:
        #   0   -> the object was an existing (pre-tracked) variable
        #   1   -> created during top-level introspection (no HOPs involved)
        #   >=2 -> created while introspecting a HigherOrderOp; the value is
        #          the nesting depth of HigherOrderOps
        if typ is MutableLocalSource.Existing:
            self.scope = 0
            return
        if typ is MutableLocalSource.Local:
            self.scope = current_scope_id()
            return
        unimplemented(f"Unsupported MutableLocalSource: {typ}")
64
65
class MutableLocal(MutableLocalBase):
    """
    Marker for an object (list, iter, etc.) built in the local scope under
    analysis, which may therefore be mutated during analysis without leaking
    state.

    Markers compare and hash by identity: two markers are equal only when
    they are the very same object.
    """

    def __init__(self) -> None:
        # Always tagged as Local, so the recorded scope is the current scope id.
        super().__init__(MutableLocalSource.Local)

    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)
81
82
83def _is_top_level_scope(scope_id):
84    return scope_id == 1
85
86
def is_side_effect_safe(m: MutableLocalBase):
    """
    Decide whether the object tagged with ``m`` may be mutated right now.

    In the top-level scope (no HigherOrderOperators involved), both existing
    objects and objects created in that scope may be mutated, so this always
    returns True. In a nested scope, only objects whose recorded ``scope``
    equals the current scope id may be mutated.
    """
    active_scope = current_scope_id()
    if not _is_top_level_scope(active_scope):
        # Nested scope: restrict mutation to objects created in this scope.
        return m.scope == active_scope
    return True
97
98
class VariableTrackerMeta(type):
    """
    Metaclass for VariableTracker.

    It records every class it creates in ``all_subclasses`` and makes
    ``isinstance`` realize LazyVariableTracker wrappers before checking.
    """

    # Shared registry of every class created with this metaclass, in creation order.
    all_subclasses = []

    def __init__(cls, name, bases, attrs) -> None:
        super().__init__(name, bases, attrs)
        # Append to the metaclass-level registry explicitly, not via ``cls``.
        VariableTrackerMeta.all_subclasses.append(cls)

    def __instancecheck__(cls, instance) -> bool:
        """Make isinstance work with LazyVariableTracker.

        A lazy wrapper is realized before the check. The exception is a check
        against VariableTracker or LazyVariableTracker itself; for those two
        classes the wrapper is checked as-is.
        """
        lazy_cls = variables.LazyVariableTracker
        if type.__instancecheck__(lazy_cls, instance):
            if cls not in (VariableTracker, lazy_cls):
                instance = instance.realize()
        return type.__instancecheck__(cls, instance)
116
117
class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values.

    VariableTracker instances are meant to be treated as immutable: use
    ``clone()`` to get a modified copy instead of changing one in place.
    ``force_unpack_var_sequence`` is documented as an exception that may
    mutate the tracker.
    """

    # Attribute names that ``visit()`` does not recurse into when walking an
    # instance's ``__dict__``. They hold raw payloads or bookkeeping, not
    # nested VariableTrackers.
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutable_local",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs):
        """Return a shallow copy of ``self``, with ``kwargs`` overriding fields.

        The copy is built by passing every entry of ``self.__dict__`` (updated
        with ``kwargs``) as a keyword argument to ``self.__class__``. Every
        instance attribute must therefore be accepted by ``__init__``.
        """
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def visit(
        cls,
        fn: Callable[["VariableTracker"], None],
        value: Any,
        cache: Optional[Dict[int, Any]] = None,
    ) -> None:
        """
        Walk ``value`` and call ``fn`` on every VariableTracker reachable from it.

        The walk recurses into the attributes of VariableTrackers (skipping
        ``_nonvar_fields``), into the elements of lists/tuples, and into the
        values of dicts/OrderedDicts. Container matching uses ``istype``,
        presumably an exact-type check, so container subclasses may not be
        traversed; confirm against ``istype``. Each object is visited at most
        once, keyed by ``id()``.

        Args:
            fn: callback invoked once per (unwrapped) VariableTracker.
            value: root object to walk; may be any Python value.
            cache: ``id() -> object`` map of visited objects. It is created on
                the top-level call and threaded through the recursion.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return
        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = value

        if isinstance(value, VariableTracker):
            value = value.unwrap()
            fn(value)
            value = value.unwrap()  # calling fn() might have realized it
            nonvars = value._nonvar_fields
            for key, subvalue in value.__dict__.items():
                if key not in nonvars:
                    cls.visit(fn, subvalue, cache)
        elif istype(value, (list, tuple)):
            for subvalue in value:
                cls.visit(fn, subvalue, cache)
        elif istype(value, (dict, collections.OrderedDict)):
            for subvalue in value.values():
                cls.visit(fn, subvalue, cache)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def debug_repr(self):
        """Repr for debugging: the constant's repr if this is a python constant,
        otherwise ``repr(self)``. Intended to be overridden to provide more info."""
        try:
            return repr(self.as_python_constant())
        except NotImplementedError:
            return repr(self)

    def python_type(self):
        """
        Abstract method to be implemented by subclasses of VariableTracker.

        This method should return the type represented by the instance of the subclass.
        The purpose is to provide a standardized way to retrieve the Python type information
        of the variable being tracked.

        The base implementation falls back to ``type(self.as_python_constant())``.

        Returns:
            type: The Python type (such as int, str, list, etc.) of the variable tracked by
                the subclass. If the type cannot be determined or is not relevant,
                leaving it undefined or invoking super() is always sound.

        Note:
            This is an abstract method and may be overridden in subclasses.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: If the method is not overridden and the
                variable is not a python constant (``as_python_constant()``
                raised NotImplementedError).
        """
        try:
            return type(self.as_python_constant())
        except NotImplementedError:
            raise NotImplementedError(f"{self} has no type") from None

    def as_python_constant(self):
        """For constants: return the Python value this variable represents.

        Raises:
            NotImplementedError: in the base class; overridden by constant-like
                subclasses.
        """
        raise NotImplementedError(f"{self} is not a constant")

    def guard_as_python_constant(self):
        """Like ``as_python_constant()``, but a failure is reported via ``unimplemented``.

        NOTE(review): an earlier docstring said this adds ID_MATCH guards to
        force the value to become a constant. The base implementation installs
        no guards. It only converts a NotImplementedError from
        ``as_python_constant()`` into ``unimplemented(str(e))``. Subclasses
        presumably override this to add guards; confirm per subclass.
        """
        try:
            return self.as_python_constant()
        except NotImplementedError as e:
            unimplemented(str(e))

    def is_python_constant(self):
        """Return True iff ``as_python_constant()`` succeeds (does not raise NotImplementedError)."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """Build a guard from this variable's source using ``fn``.

        Raises:
            NotImplementedError: if the variable has no (truthy) ``source``.
        """
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError

    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
        """getattr(self, name) returning a python constant.

        Raises:
            NotImplementedError: in the base class.
        """
        raise NotImplementedError

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable.

        The base implementation wraps ``const_getattr`` in a ConstantVariable.
        When ``self`` has a source, the result's source is
        ``AttrSource(self.source, name)``.

        Raises:
            NotImplementedError: if ``const_getattr`` raises it, or if the
                value is not a literal according to ``ConstantVariable.is_literal``.
        """
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError
        source = None
        if self.source:
            source = AttrSource(self.source, name)
        return variables.ConstantVariable.create(value, source=source)

    def is_proxy(self):
        """Return True iff ``as_proxy()`` succeeds (does not raise NotImplementedError)."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        """Return the proxy backing this variable; the base class has none and raises NotImplementedError."""
        raise NotImplementedError(str(self))

    def maybe_fx_node(self):
        """Return the ``torch.fx`` node behind ``as_proxy()``, or None.

        Returns None if there is no proxy, or if the proxy is not a ``torch.fx.Proxy``.
        """
        try:
            proxy = self.as_proxy()
            # Local import: torch.fx is only needed on this path.
            import torch.fx

            if isinstance(proxy, torch.fx.Proxy):
                return proxy.node
            return None
        except NotImplementedError:
            return None

    def reconstruct(self, codegen):
        """Emit code through ``codegen`` that rebuilds this Python object.

        Raises:
            NotImplementedError: in the base class.
        """
        raise NotImplementedError

    def can_reconstruct(self, tx):
        """If it is possible to reconstruct the Python object this
        VariableTracker represents.

        The check runs ``reconstruct`` into a throwaway ``PyCodegen`` and
        reports whether it raised NotImplementedError. It asserts that
        ``tx`` is the root translator.
        """
        assert tx is tx.output.root_tx, "Only root tx can reconstruct"
        try:
            from ..codegen import PyCodegen

            cg = PyCodegen(tx)
            self.reconstruct(cg)
            return True
        except NotImplementedError:
            return False

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the element trackers of this sequence-like variable.

        Raises:
            NotImplementedError: in the base class.
        """
        raise NotImplementedError

    def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        # like unpack_var_sequence, but should only be used when it is
        # safe to eagerly (vs. lazily) unpack this variable.
        # e.g. map(f, x) is normally evaluated lazily but sometimes
        # we want to force eager unpacking, e.g. when converting to a list.
        # NOTE: this method is allowed to mutate the VariableTracker, so
        # it should only be called once.
        return self.unpack_var_sequence(tx)

    def has_unpack_var_sequence(self, tx) -> bool:
        """Return True iff ``unpack_var_sequence(tx)`` succeeds.

        Note that this actually performs the unpack.
        """
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    # NB: don't call force_unpack_var_sequence, especially if it mutates!
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return self.has_unpack_var_sequence(tx)

    def inspect_parameter_names(self) -> List[str]:
        """Names of the parameters of the callable this variable represents; unsupported in the base class."""
        unimplemented(f"inspect_parameter_names: {self}")

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Model ``hasattr(self, name)``; unsupported in the base class."""
        unimplemented(f"hasattr {self.__class__.__name__} {name}")

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model calling this variable with ``args``/``kwargs``; unsupported in the base class."""
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model ``self.<name>(*args, **kwargs)``.

        Generic fallbacks handled here:
          * ``__len__`` on anything that can be unpacked: a constant length.
          * ``__getattr__`` with a single constant argument and no kwargs:
            delegates to ``var_getattr``.
        Anything else is reported via ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant())
        unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def set_name_hint(self, name):
        """Hook for naming hints; a no-op in the base class."""
        pass

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self):
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def next_variable(self, tx):
        """Model ``next(self)``; unsupported in the base class."""
        unimplemented(f"next({self})")

    def is_strict_mode(self, tx):
        """Return a truthy value if ``tx.strict_checks_fn`` is set and accepts ``self``.

        The result may be falsy but non-bool (e.g. None) when no strict-check
        function is installed.
        """
        return tx.strict_checks_fn and tx.strict_checks_fn(self)

    def __init__(
        self,
        *,
        source: Optional[Source] = None,
        mutable_local: Optional[MutableLocalBase] = None,
    ) -> None:
        """
        Args:
            source: where the tracked value came from, if known.
            mutable_local: mutability/scope marker. It may be any
                MutableLocalBase, not only MutableLocal; None if the value is
                not tracked as mutable.
        """
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local
375
376
def typestr(*objs):
    """
    Describe ``objs`` for diagnostic messages.

    A single VariableTracker is rendered with ``str()``. Any other single
    object is rendered as the ``__name__`` of its type. With zero or several
    objects, each one is described this way and the results are joined with
    single spaces, so no objects gives "".
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)
    obj = objs[0]
    return str(obj) if isinstance(obj, VariableTracker) else type(obj).__name__
386