# mypy: allow-untyped-defs
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm

__all__ = ['Interpreter', 'Transformer']

@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (torch.nn.Module): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
        graph (Optional[Graph]): If passed, the interpreter will execute this
            graph instead of `module.graph`, using the provided `module`
            argument to satisfy any requests for state.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
        self.module = module
        self.submodules = dict(self.module.named_modules())
        if graph is not None:
            self.graph = graph
        else:
            self.graph = self.module.graph
        self.env : Dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        self.extra_traceback = True

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.graph.nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter.
            enable_io_processing (bool): If true, the inputs and outputs will first be processed
                with the graph's ``process_inputs`` and ``process_outputs`` functions before use.

        Returns:
            Any: The value returned from executing the Module
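
        Example:

            A minimal sketch of partial evaluation via ``initial_env`` (the
            traced function and the pre-populated node are assumptions for
            illustration)::

                gm = torch.fx.symbolic_trace(lambda x: x.relu() + 1.0)
                relu_node = next(n for n in gm.graph.nodes if n.op == 'call_method')
                # Pre-populate the relu result; only the remaining add runs
                out = Interpreter(gm).run(torch.zeros(2), initial_env={relu_node: torch.ones(2)})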
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)
        pbar = tqdm(total=len(self.graph.nodes),
                    desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)

        for node in self.graph.nodes:
            pbar.update(1)
            if node in self.env:
                # Short circuit if we have this value. This could
                # be used, for example, for partial evaluation
                # where the caller has pre-populated `env` with
                # values for a subset of the program.
                continue

            try:
                self.env[node] = self.run_node(node)
            except Exception as e:
                if self.extra_traceback:
                    msg = f"While executing {node.format_node()}"
                    msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
                    msg += f"\nOriginal traceback:\n{node.stack_trace}"
                    e.args = (msg,) + e.args[1:]
                    if isinstance(e, KeyError):
                        raise RuntimeError(*e.args) from e
                raise

            if self.garbage_collect_values:
                for to_delete in self.user_to_last_uses.get(node, []):
                    del self.env[to_delete]

            if node.op == 'output':
                output_val = self.env[node]
                return self.graph.process_outputs(output_val) if enable_io_processing else output_val

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result.  This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter.  This ensures that input tensors are promptly deallocated.
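
        Example:

            An illustrative sketch (``gm`` stands in for any ``GraphModule``)::

                args = [torch.randn(3)]
                out = Interpreter(gm).boxed_run(args)
                assert args == []  # the input list was cleared by the interpreter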
        """
        args_iter = iter(args_list)
        env = {}
        for n in self.graph.nodes:
            if n.op == "placeholder":
                env[n] = next(args_iter)
        args_list.clear()
        return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``.

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
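
        Example:

            A sketch of a common override pattern (the subclass name is
            hypothetical): wrap ``super().run_node`` to observe each node
            as it executes::

                class LoggingInterpreter(Interpreter):
                    def run_node(self, n):
                        result = super().run_node(n)
                        print(f"{n.format_node()} -> {type(result).__name__}")
                        return result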
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        the next value from that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.
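
        Example:

            A sketch of the default-value fallback (the traced function is
            an assumption for illustration)::

                def f(x, y=2.0):
                    return x + y

                gm = torch.fx.symbolic_trace(f)
                # Only one positional arg is passed; the `y` placeholder's
                # recorded default (2.0) is used once the iterator is exhausted
                out = Interpreter(gm).run(torch.ones(2))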
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                if len(args) > 0:
                    # Fall back to the default value recorded on the
                    # placeholder node, if one exists.
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the module invocation
        """
        # Retrieve the submodule, then execute it and return the result
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Return:
            Any: The value of the attribute.
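
        Example:

            An illustrative call (``interp`` and the ``linear`` submodule
            are hypothetical)::

                # Resolves self.module.linear.weight by splitting on '.'
                weight = interp.fetch_attr('linear.weight')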
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Return:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)

@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        self.new_graph = Graph()
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph
                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]

            def is_leaf_module(self, _, __) -> bool:
                # Treat every submodule as a leaf so that `call_module`
                # nodes are copied over rather than traced through.
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.
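
        Example:

            A minimal sketch (a base ``Transformer`` performs an identity
            transformation; the traced function is an assumption for
            illustration)::

                gm = torch.fx.symbolic_trace(lambda x: x + 1.0)
                new_gm = Transformer(gm).transform()
                x = torch.randn(3)
                torch.testing.assert_close(new_gm(x), gm(x))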
        """
        with fx_traceback.preserve_node_meta():
            result = super().run(enable_io_processing=False)
        if result is not None:
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                return a.node if isinstance(a, Proxy) else a
            new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
            # also preserve the metadata from the old output node, if it exists
            old_output_node = list(self.graph.nodes)[-1]
            assert old_output_node.op == "output"
            for k, v in old_output_node.meta.items():
                new_output_node.meta[k] = v

        return _make_graph_module(self.module, self.new_graph)