xref: /aosp_15_r20/external/pytorch/torch/fx/passes/fake_tensor_prop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Optional
3
4import torch.fx
5from torch.fx import Node
6from torch.fx.node import map_aggregate
7from torch.fx._compatibility import compatibility
8from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
9from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types
10
11__all__ = ['FakeTensorProp']
12
13@compatibility(is_backward_compatible=False)
14class FakeTensorProp(torch.fx.Interpreter):
15    """
16    Execute an FX graph Node-by-Node and record a fake tensor representing
17    the metadata for the node.  Unlike ShapeProp, (1) this propagation
18    is cheap--it does the propagation with meta tensors which do not actually
19    store data, and (2) the fake tensors have much more fine grained information,
20    e.g., they have accurate alias information that can be consulted by looking
21    at the storages.
22
23    Args:
24         module (GraphModule): The module to be executed
25         mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
26    """
27    def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
28        super().__init__(module)
29        if mode is None:
30            mode = FakeTensorMode()
31        self._mode = mode
32        mode.epoch += 1
33        mode.reset_nt_tensor_id_counter()
34
35    def run_node(self, n: Node):
36        from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings
37
38        result = super().run_node(n)
39        rebind_unbacked(self._mode.shape_env, n, result)
40
41        def extract_val(obj):
42            if isinstance(obj, FakeTensor):
43                return snapshot_fake(obj)
44            elif isinstance(obj, torch.Tensor):
45                # TODO: How is it possible that we get a non fake tensor?  We
46                # should be running under the mode...
47                return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
48            elif isinstance(obj, py_sym_types):
49                return obj
50            else:
51                return None
52
53        meta = map_aggregate(result, extract_val)
54        if meta is not None:
55            n.meta['val'] = meta
56            if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)):
57                n.meta["unbacked_bindings"] = symbol_to_path
58
59        return result
60
61    def propagate(self, *args):
62        fake_args = [
63            self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
64            for a in args
65        ]
66        return self.propagate_dont_convert_inputs(*fake_args)
67
68    def propagate_dont_convert_inputs(self, *args):
69        with self._mode:
70            return super().run(*args)
71