# xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/meta_tracer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
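"""Experimental FX tracer that performs meta-tensor propagation while tracing: each
proxy records the shape/dtype of the value it represents (computed on the ``meta``
device, so no real storage is allocated), with manual overrides for ops whose meta
behavior has to be special-cased."""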
import torch
import torch.fx
import warnings
import functools
import builtins

from typing import Any, Callable, Dict, Optional, Union

def embedding_override(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device='meta')


def nn_layernorm_override(self, input):
    return input


def torch_relu_override(x):
    return x


def torch_nn_relu_override(self, x):
    return x


def functional_relu_override(x, inplace=False):
    assert not inplace, "don't support in-place functional.relu for meta tensor analysis"
    return x


def torch_where_override(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')


def torch_abs_override(input, *, out=None):
    assert out is None, "don't support in-place abs for meta tensor analysis"
    return input

# Map functions and module types to manual "meta" implementations used for
# shape/dtype propagation during tracing.
manual_meta_overrides : Dict[Callable, Callable] = {
    torch.nn.Embedding: embedding_override,
    torch.nn.LayerNorm: nn_layernorm_override,
    torch.relu: torch_relu_override,
    torch.nn.functional.relu: functional_relu_override,
    torch.nn.ReLU: torch_nn_relu_override,
    torch.where: torch_where_override,
    torch.abs: torch_abs_override,
}

def gen_constructor_wrapper(target):
    """Wrap a tensor constructor (e.g. ``torch.zeros``) so that calls whose arguments
    contain an ``fx.Proxy`` are recorded as ``call_function`` nodes rather than executed
    eagerly; calls with only concrete arguments fall through to the original constructor.
    Returns the pair ``(wrapper, original_target)``."""
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy('call_function', target, args, kwargs)
        else:
            return target(*args, **kwargs)
    return wrapper, target

class MetaProxy(torch.fx.Proxy):
    def install_tensor_meta(self, tensor_meta):
        self._tensor_meta = tensor_meta

    def size(self, dim=None):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            # Check `dim is not None` explicitly so that `size(0)` is handled correctly.
            return self._tensor_meta.size(*[dim] if dim is not None else [])
        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim is not None else (self,), {})

    def dim(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dim()
        return self.tracer.create_proxy('call_method', 'dim', (self,), {})

    @property
    def shape(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.shape
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})

    @property
    def dtype(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dtype
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # these values are replaced with the constant 'meta'.
        return MetaDeviceAttribute(self, 'device')

    def __getattr__(self, k):
        if k == '_tensor_meta':
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return MetaAttribute(self, k)

class MetaAttribute(MetaProxy):
    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

class MetaDeviceAttribute(MetaAttribute):
    pass

def proxys_to_metas(v):
    if isinstance(v, MetaDeviceAttribute):
        return 'meta'
    if isinstance(v, torch.fx.Proxy):
        assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
        assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
        return v._tensor_meta
    return v

class MetaTracer(torch.fx.Tracer):
    allow_insert_stateless_mods : bool = True

    _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == 'placeholder' and target in self.meta_args:
            rv.install_tensor_meta(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if 'device' in kwargs:
                kwargs['device'] = 'meta'

        try:
            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)

            if kind == 'call_function':
                meta_target = manual_meta_overrides.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == 'call_method':
                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)  # type: ignore[index]
            elif kind == 'call_module':
                assert hasattr(self, 'orig_forward')
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in manual_meta_overrides:
                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)  # type: ignore[misc, arg-type]
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == 'get_attr':
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split('.')
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    assert isinstance(attr_itr, torch.Tensor)
                    meta_out = attr_itr.to(device='meta')
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            # TODO: support composite (non-Proxy) outputs
            assert isinstance(rv, torch.fx.Proxy), "don't support composite output yet"
            rv.install_tensor_meta(meta_out)
        except Exception as e:
            warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')

        return rv

    def getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, '_disable_module_getattr', False):
            return attr_val
        else:
            return super().getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as a submodule.
        """
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        while hasattr(self.root, path):
            idx += 1
            path = f"{mod_name}_{idx}"

        self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: torch.nn.Module) -> str:
        try:
            return super().path_of_module(mod)
        except NameError:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                self.prev_module = path
                return path
            raise

    def proxy(self, node):
        return MetaProxy(node, self)

    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
        assert isinstance(meta_args, dict)
        self.meta_args = meta_args

        self.patched_torch_methods = {
            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            graph = super().trace(root, concrete_args)
            graph._tracer_extras = {'meta_args': meta_args}
            return graph
        finally:
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)


def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
                   meta_args : Optional[Dict[str, torch.Tensor]] = None,
                   concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
    tracer = MetaTracer()
    graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    gm = torch.fx.GraphModule(tracer.root, graph, name)
    return gm

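
# A minimal usage sketch (illustrative only; the model, parameter sizes, and input
# shape below are hypothetical). Tracing against meta tensors records shape/dtype
# metadata for every node without allocating real storage for the example input.
if __name__ == '__main__':
    class _ExampleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(10, 4)

        def forward(self, x):
            return torch.relu(self.emb(x))

    example_meta_args = {'x': torch.zeros(2, 3, dtype=torch.long, device='meta')}
    gm = symbolic_trace(_ExampleModel(), meta_args=example_meta_args)
    print(gm.graph)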