# mypy: allow-untyped-defs
import collections
import functools
from typing import Optional

from .base import VariableTracker
from .tensor import SymNodeVariable


class LazyCache:
    """
    Holds a (value, source) pair until a VariableTracker is built from it.

    After ``realize`` runs, ``vt`` holds the built tracker and the ``value``
    and ``source`` attributes are deleted.
    """

    def __init__(self, value, source) -> None:
        # Everything except a LazySymNodeFormatString must carry a source.
        assert isinstance(value, LazySymNodeFormatString) or source
        self.value = value
        self.source = source
        self.vt: Optional[VariableTracker] = None

    def realize(self):
        """Build ``self.vt`` from the cached value, then drop the inputs."""
        assert self.vt is None
        # Imported here rather than at module level, as in the original layout.
        from ..symbolic_convert import InstructionTranslator
        from .builder import SourcelessBuilder, VariableBuilder

        tx = InstructionTranslator.current_tx()
        # Format-string placeholders may lack a source, so they go through the
        # sourceless builder; every other value is wrapped using its source.
        self.vt = (
            SourcelessBuilder.create(tx, self.value)
            if isinstance(self.value, LazySymNodeFormatString)
            else VariableBuilder(tx, self.source)(self.value)
        )

        del self.value, self.source


class LazyVariableTracker(VariableTracker):
    """
    Stand-in VariableTracker whose real tracker is built only on first use.

    ``realize`` asks the shared LazyCache to run the builder. Because the
    cache remembers the result, a tracker is never built twice.

    Most VariableTracker methods are forwarded automatically (see
    ``_populate`` at the bottom of this module): calling them realizes the
    tracker and delegates to the real object. Use this for containers, or
    for objects that reference other objects, when building every nested
    VariableTracker eagerly would be wasteful.
    """

    _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}

    @staticmethod
    def create(value, source, **options):
        """Wrap ``value`` (found at ``source``) in a fresh lazy tracker."""
        cache = LazyCache(value, source)
        return LazyVariableTracker(cache, source=source, **options)

    def __init__(self, _cache, **kwargs) -> None:
        assert isinstance(_cache, LazyCache)
        super().__init__(**kwargs)
        self._cache = _cache

    def realize(self) -> VariableTracker:
        """Return the real VariableTracker, building it if needed."""
        cache = self._cache
        if cache.vt is None:
            cache.realize()
        assert cache.vt is not None
        return cache.vt

    def unwrap(self):
        """Return the real tracker if already built, otherwise ``self``."""
        return self._cache.vt if self.is_realized() else self

    def is_realized(self):
        """True once the shared cache holds a built VariableTracker."""
        return self._cache.vt is not None

    def clone(self, **kwargs):
        """
        Clone via VariableTracker.clone on the unwrapped object.

        The cache can never be swapped out. Requesting a different ``source``
        forces realization first, so the clone is made from the real tracker.
        """
        assert kwargs.get("_cache", self._cache) is self._cache
        source_changes = kwargs.get("source", self.source) is not self.source
        if source_changes:
            self.realize()
        return VariableTracker.clone(self.unwrap(), **kwargs)

    def __str__(self) -> str:
        # Do not force realization just to print: an unrealized tracker uses
        # the generic VariableTracker rendering of itself.
        if not self.is_realized():
            return VariableTracker.__str__(self.unwrap())
        return self.unwrap().__str__()

    def __getattr__(self, item):
        # Any attribute not found normally is looked up on the real tracker.
        return getattr(self.realize(), item)

    # Class attributes defined here are skipped by the auto-forwarding below;
    # these two are pinned to the base implementations on purpose.
    visit = VariableTracker.visit  # type: ignore[assignment]
    __repr__ = VariableTracker.__repr__

    @classmethod
    def realize_all(
        cls,
        value,
        cache=None,
    ):
        """
        Recursively replace every LazyVariableTracker reachable from ``value``
        with its realized VariableTracker.

        Handling by type of ``value``:

        - Plain VariableTrackers are updated in place: every attribute not
          listed in ``_nonvar_fields`` is visited.
        - Lists and tuples are rebuilt.
        - Dicts and OrderedDicts are rebuilt as plain dicts.
        - Anything else is returned unchanged.

        ``cache`` maps ``id(obj)`` to ``(result, obj)``, so each object is
        processed once per walk.
        """
        if cache is None:
            cache = {}

        obj_id = id(value)
        if obj_id in cache:
            return cache[obj_id][0]

        kind = type(value)
        if issubclass(kind, LazyVariableTracker):
            # Realize, then keep walking whatever the real tracker contains.
            result = cls.realize_all(value.realize(), cache)
        elif issubclass(kind, VariableTracker):
            # Rewrite the tracker's own attribute dict. Only keys that already
            # exist are reassigned, so the dict's size does not change here.
            fields = value.__dict__
            skipped = value._nonvar_fields
            for field in fields:
                if field in skipped:
                    continue
                fields[field] = cls.realize_all(fields[field], cache)
            result = value
        elif kind is list:
            result = [cls.realize_all(item, cache) for item in value]
        elif kind is tuple:
            result = tuple(cls.realize_all(item, cache) for item in value)
        elif kind in (dict, collections.OrderedDict):
            # Iterate over a snapshot of the items rather than the live view.
            result = {
                key: cls.realize_all(item, cache)
                for key, item in list(value.items())
            }
        else:
            result = value

        # Keep `value` itself referenced from the cache so its id() cannot be
        # recycled for a different object while this walk is in progress.
        cache[obj_id] = (result, value)
        return result


class LazySymNodeFormatString:
    """
    Deferred ``format(sym_node, spec)`` whose text is produced on ``str()``.

    The format spec is turned into a ``"{:<spec>}"`` template up front. The
    symbolic value is only evaluated when the object is converted to a
    string.
    """

    def __init__(
        self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
    ) -> None:
        from .constant import ConstantVariable

        self.sym_node_var = sym_node_variable
        spec = fmt_spec_var.as_python_constant()
        self.fmt_var = ConstantVariable.create("{:" + spec + "}")

    def __str__(self) -> str:
        template = self.fmt_var.as_python_constant()
        # The evaluated expression is stringified before the template is
        # applied, so the spec acts on its string form.
        rendered = str(self.sym_node_var.evaluate_expr())
        return str.format(template, rendered)


def _create_realize_and_forward(name):
    """Build a method that realizes ``self`` and calls ``name`` on the result."""
    base_member = getattr(VariableTracker, name)

    @functools.wraps(base_member)
    def realize_and_forward(self, *args, **kwargs):
        return getattr(self.realize(), name)(*args, **kwargs)

    return realize_and_forward


def _populate():
    """
    Install a forwarding method on LazyVariableTracker for every callable
    in VariableTracker's class dict that LazyVariableTracker does not
    define itself.
    """
    own_members = LazyVariableTracker.__dict__  # live view of the class dict
    for attr, member in VariableTracker.__dict__.items():
        if attr in own_members or not callable(member):
            continue
        setattr(LazyVariableTracker, attr, _create_realize_and_forward(attr))


_populate()