xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/lazy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import functools
4from typing import Optional
5
6from .base import VariableTracker
7from .tensor import SymNodeVariable
8
9
class LazyCache:
    """Holds the inputs for a deferred VariableTracker and, later, the result.

    ``vt`` stays ``None`` until :meth:`realize` runs. Once ``realize`` has
    built the tracker, the ``value`` and ``source`` attributes are deleted so
    the cache no longer keeps them alive.
    """

    def __init__(self, value, source) -> None:
        # A source is mandatory, except for LazySymNodeFormatString values,
        # which realize() builds through SourcelessBuilder instead.
        assert isinstance(value, LazySymNodeFormatString) or source
        self.value = value
        self.source = source
        self.vt: Optional[VariableTracker] = None

    def realize(self):
        """Build the real VariableTracker (exactly once) and drop the inputs."""
        assert self.vt is None
        # Function-level imports; presumably to avoid an import cycle with
        # these modules (NOTE(review): confirm).
        from ..symbolic_convert import InstructionTranslator
        from .builder import SourcelessBuilder, VariableBuilder

        tx = InstructionTranslator.current_tx()
        sourceless = isinstance(self.value, LazySymNodeFormatString)
        self.vt = (
            SourcelessBuilder.create(tx, self.value)
            if sourceless
            else VariableBuilder(tx, self.source)(self.value)
        )

        # The raw value/source are no longer needed once the tracker exists.
        del self.value, self.source
33
34
class LazyVariableTracker(VariableTracker):
    """
    A structure that defers the creation of the actual VariableTracker
    for a given underlying value until it is accessed.

    The `realize` function invokes VariableBuilder to produce the real object.
    Once a LazyVariableTracker has been realized, internal bookkeeping will
    prevent double realization.

    This object should be utilized for processing containers, or objects that
    reference other objects where we may not want to take on creating all the
    VariableTrackers right away.

    Methods of VariableTracker that are not defined on this class are filled
    in at import time by `_populate()` with shims that realize the tracker
    and forward the call; unknown attributes are forwarded by `__getattr__`.
    """

    # `_cache` is bookkeeping, not a traced value: realize_all skips it.
    _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}

    @staticmethod
    def create(value, source, **options):
        """Wrap `value` (reachable via `source`) in an unrealized tracker."""
        return LazyVariableTracker(LazyCache(value, source), source=source, **options)

    def __init__(self, _cache, **kwargs) -> None:
        assert isinstance(_cache, LazyCache)
        super().__init__(**kwargs)
        self._cache = _cache

    def realize(self) -> VariableTracker:
        """Force construction of the real VariableTracker"""
        if self._cache.vt is None:
            self._cache.realize()
            assert self._cache.vt is not None
        return self._cache.vt

    def unwrap(self):
        """Return the real VariableTracker if it already exists"""
        if self.is_realized():
            return self._cache.vt
        return self

    def is_realized(self):
        """Whether the underlying VariableTracker has already been built."""
        return self._cache.vt is not None

    def clone(self, **kwargs):
        """Clone the realized tracker if it exists, otherwise this wrapper.

        Supplying a different ``source`` forces realization first, since
        LazyCache.realize builds from the source stored in the cache, which a
        clone of the wrapper would not update.
        """
        # The shared cache must never be replaced through clone().
        assert kwargs.get("_cache", self._cache) is self._cache
        if kwargs.get("source", self.source) is not self.source:
            self.realize()
        return VariableTracker.clone(self.unwrap(), **kwargs)

    def __str__(self) -> str:
        if self.is_realized():
            return self.unwrap().__str__()
        # Not realized yet: unwrap() would just return self.
        return VariableTracker.__str__(self)

    def __getattr__(self, item):
        """Forward unknown attribute lookups to the realized tracker."""
        # __getattr__ only runs when normal lookup fails. If `_cache` itself is
        # missing (an instance created via __new__ without __init__, e.g. by
        # copy or pickle), forwarding would call realize() -> self._cache ->
        # __getattr__("_cache") again and recurse until RecursionError.
        if item == "_cache":
            raise AttributeError(item)
        return getattr(self.realize(), item)

    # most methods are auto-generated below, these are the ones we want to exclude
    visit = VariableTracker.visit  # type: ignore[assignment]
    __repr__ = VariableTracker.__repr__

    @classmethod
    def realize_all(
        cls,
        value,
        cache=None,
    ):
        """
        Walk an object and realize all LazyVariableTrackers inside it.

        Lazy trackers are replaced by their realized trackers. Other
        VariableTrackers are updated in place (every field outside
        ``_nonvar_fields`` is walked). Lists, tuples and dicts are rebuilt;
        an OrderedDict comes back as a plain dict. Any other value is returned
        unchanged.

        ``cache`` maps ``id(obj) -> (result, obj)`` so shared sub-objects are
        processed once; ``obj`` is stored to keep it alive so its id() is not
        reused during the walk.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        value_cls = type(value)
        if issubclass(value_cls, LazyVariableTracker):
            result = cls.realize_all(value.realize(), cache)
        elif issubclass(value_cls, VariableTracker):
            # update value in-place. The result is `value` itself, so register
            # it before walking its fields: a field that leads back to `value`
            # then hits the cache instead of recursing forever.
            result = value
            cache[idx] = (result, value)
            value_dict = value.__dict__
            nonvars = value._nonvar_fields
            for key in value_dict:
                if key not in nonvars:
                    value_dict[key] = cls.realize_all(value_dict[key], cache)
        elif value_cls is list:
            result = [cls.realize_all(v, cache) for v in value]
        elif value_cls is tuple:
            result = tuple(cls.realize_all(v, cache) for v in value)
        elif value_cls in (dict, collections.OrderedDict):
            result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result
133
134
class LazySymNodeFormatString:
    """Deferred rendering of ``format(<symbolic scalar>, <format spec>)``.

    Built from a SymNodeVariable and a VariableTracker whose
    ``as_python_constant()`` yields the format-spec string. The spec is read
    eagerly and stored as a ``"{:<spec>}"`` template. The symbolic value is
    only evaluated (via ``evaluate_expr()``) when ``__str__`` is called.
    """

    def __init__(
        self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
    ) -> None:
        from .constant import ConstantVariable

        self.sym_node_var = sym_node_variable
        self.fmt_var = ConstantVariable.create(
            "{:" + fmt_spec_var.as_python_constant() + "}"
        )

    def __str__(self) -> str:
        # Format the evaluated value itself, as Python's f"{x:spec}" does.
        # Converting it with str() first made numeric specs such as ".2f",
        # "05d" or "x" raise ValueError, because str does not support them.
        # NOTE(review): assumes evaluate_expr() returns a plain Python
        # int/float/bool — confirm against SymNodeVariable.
        return str.format(
            self.fmt_var.as_python_constant(),
            self.sym_node_var.evaluate_expr(),
        )
151
152
def _create_realize_and_forward(name):
    """Return a method that realizes the lazy tracker, then calls `name` on it.

    The attribute is looked up on the realized object at call time, so an
    override defined by the realized tracker's own class is the one invoked.
    The shim copies metadata (``__name__``, ``__doc__``, ...) from
    ``VariableTracker.<name>`` via ``functools.wraps``.
    """
    template = getattr(VariableTracker, name)

    def realize_and_forward(self, *args, **kwargs):
        realized = self.realize()
        bound = getattr(realized, name)
        return bound(*args, **kwargs)

    return functools.wraps(template)(realize_and_forward)
159
160
def _populate():
    """Give LazyVariableTracker a forwarding shim for each inherited callable.

    Every callable found directly in ``VariableTracker.__dict__`` that
    LazyVariableTracker does not define itself is replaced by a method that
    realizes the tracker and delegates to it.
    """
    lazy_namespace = LazyVariableTracker.__dict__
    for attr_name, attr in VariableTracker.__dict__.items():
        if attr_name in lazy_namespace or not callable(attr):
            continue
        forwarder = _create_realize_and_forward(attr_name)
        setattr(LazyVariableTracker, attr_name, forwarder)


_populate()
169