1# mypy: allow-untyped-defs 2from typing import Optional 3 4import torch.fx 5from torch.fx import Node 6from torch.fx.node import map_aggregate 7from torch.fx._compatibility import compatibility 8from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor 9from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types 10 11__all__ = ['FakeTensorProp'] 12 13@compatibility(is_backward_compatible=False) 14class FakeTensorProp(torch.fx.Interpreter): 15 """ 16 Execute an FX graph Node-by-Node and record a fake tensor representing 17 the metadata for the node. Unlike ShapeProp, (1) this propagation 18 is cheap--it does the propagation with meta tensors which do not actually 19 store data, and (2) the fake tensors have much more fine grained information, 20 e.g., they have accurate alias information that can be consulted by looking 21 at the storages. 22 23 Args: 24 module (GraphModule): The module to be executed 25 mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. 26 """ 27 def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): 28 super().__init__(module) 29 if mode is None: 30 mode = FakeTensorMode() 31 self._mode = mode 32 mode.epoch += 1 33 mode.reset_nt_tensor_id_counter() 34 35 def run_node(self, n: Node): 36 from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings 37 38 result = super().run_node(n) 39 rebind_unbacked(self._mode.shape_env, n, result) 40 41 def extract_val(obj): 42 if isinstance(obj, FakeTensor): 43 return snapshot_fake(obj) 44 elif isinstance(obj, torch.Tensor): 45 # TODO: How is it possible that we get a non fake tensor? We 46 # should be running under the mode... 47 return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True)) 48 elif isinstance(obj, py_sym_types): 49 return obj 50 else: 51 return None 52 53 meta = map_aggregate(result, extract_val) 54 if meta is not None: 55 n.meta['val'] = meta 56 if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)): 57 n.meta["unbacked_bindings"] = symbol_to_path 58 59 return result 60 61 def propagate(self, *args): 62 fake_args = [ 63 self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a 64 for a in args 65 ] 66 return self.propagate_dont_convert_inputs(*fake_args) 67 68 def propagate_dont_convert_inputs(self, *args): 69 with self._mode: 70 return super().run(*args) 71