# mypy: allow-untyped-defs
import torch
import torch.fx
import warnings
import functools
import builtins

from typing import Any, Callable, Dict, Optional, Union

def embedding_override(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device='meta')


def nn_layernorm_override(self, input):
    return input


def torch_relu_override(x):
    return x


def torch_nn_relu_override(self, x):
    return x


def functional_relu_override(x, inplace=False):
    assert not inplace, "don't support in-place functional.relu for MetaTensor analysis"
    return x


def torch_where_override(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')


def torch_abs_override(input, *, out=None):
    assert out is None, "don't support in-place abs for MetaTensor analysis"
    return input

# Manual shape-propagation rules: maps an op or module type to a function that
# computes its output on meta tensors.
manual_meta_overrides : Dict[Callable, Callable] = {
    torch.nn.Embedding: embedding_override,
    torch.nn.LayerNorm: nn_layernorm_override,
    torch.relu: torch_relu_override,
    torch.nn.functional.relu: functional_relu_override,
    torch.nn.ReLU: torch_nn_relu_override,
    torch.where: torch_where_override,
    torch.abs: torch_abs_override,
}

def gen_constructor_wrapper(target):
    # Wraps a tensor constructor (e.g. torch.zeros) so that calls receiving a
    # Proxy argument are recorded as `call_function` nodes instead of being
    # executed eagerly.
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy('call_function', target, args, kwargs)
        else:
            return target(*args, **kwargs)
    return wrapper, target

class MetaProxy(torch.fx.Proxy):
    def install_tensor_meta(self, tensor_meta):
        self._tensor_meta = tensor_meta

    def size(self, dim=None):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            # Use `dim is not None` so that `size(0)` is forwarded correctly.
            return self._tensor_meta.size(*[dim] if dim is not None else [])
        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim is not None else (self,), {})

    def dim(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dim()
        return self.tracer.create_proxy('call_method', 'dim', (self,), {})

    @property
    def shape(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.shape
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})

    @property
    def dtype(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dtype
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, 'device')

    def __getattr__(self, k):
        if k == '_tensor_meta':
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return MetaAttribute(self, k)

class MetaAttribute(MetaProxy):
    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

class MetaDeviceAttribute(MetaAttribute):
    pass

def proxys_to_metas(v):
    if isinstance(v, MetaDeviceAttribute):
        return 'meta'
    if isinstance(v, torch.fx.Proxy):
        assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
        assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
        return v._tensor_meta
    return v

class MetaTracer(torch.fx.Tracer):
    allow_insert_stateless_mods : bool = True

    _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == 'placeholder' and target in self.meta_args:
            rv.install_tensor_meta(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if 'device' in kwargs:
                kwargs['device'] = 'meta'

        try:
            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)

            if kind == 'call_function':
                meta_target = manual_meta_overrides.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == 'call_method':
                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)  # type: ignore[index]
            elif kind == 'call_module':
                assert hasattr(self, 'orig_forward')
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in manual_meta_overrides:
                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)  # type: ignore[misc, arg-type]
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == 'get_attr':
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split('.')
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    assert isinstance(attr_itr, torch.Tensor)
                    meta_out = attr_itr.to(device='meta')
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            # TODO
            assert isinstance(rv, torch.fx.Proxy), "don't support composite output yet"
            rv.install_tensor_meta(meta_out)
        except Exception as e:
            warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')

        return rv

    def getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, '_disable_module_getattr', False):
            return attr_val
        else:
            return super().getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as a submodule.
        """
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        # Find the first unused attribute name of the form `{mod_name}_{idx}`.
        while hasattr(self.root, path):
            idx += 1
            path = f"{mod_name}_{idx}"

        self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: torch.nn.Module) -> str:
        try:
            return super().path_of_module(mod)
        except NameError:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                self.prev_module = path
                return path
            raise

    def proxy(self, node):
        return MetaProxy(node, self)

    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
        assert isinstance(meta_args, dict)
        self.meta_args = meta_args

        self.patched_torch_methods = {
            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            graph = super().trace(root, concrete_args)
            graph._tracer_extras = {'meta_args': meta_args}
            return graph
        finally:
            # Restore the patched constructors even if tracing fails.
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)


def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
                   meta_args : Optional[Dict[str, torch.Tensor]] = None,
                   concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
    tracer = MetaTracer()
    graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    gm = torch.fx.GraphModule(tracer.root, graph, name)
    return gm
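

# Usage sketch: a minimal, hypothetical example of tracing a toy module with
# meta tensors so that shape metadata is recorded without allocating real data.
# `_ExampleModule` and the input shape below are illustrative assumptions, not
# part of this module's API.
if __name__ == '__main__':
    class _ExampleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(10, 4)
            self.norm = torch.nn.LayerNorm(4)

        def forward(self, x):
            return torch.relu(self.norm(self.emb(x)))

    # `meta_args` maps placeholder names to meta tensors describing the expected
    # input shapes; no real storage is materialized during tracing.
    gm = symbolic_trace(
        _ExampleModule(),
        meta_args={'x': torch.zeros(2, 3, dtype=torch.long, device='meta')},
    )
    print(gm.graph)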