r'''
FX is a toolkit for developers to use to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer**,
an **intermediate representation**, and **Python code generation**. A
demonstration of these components in action:

::

    import torch
    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    module = MyModule()

    from torch.fx import symbolic_trace
    # Symbolic tracing frontend - captures the semantics of the module
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

    # High-level intermediate representation (IR) - Graph representation
    print(symbolic_traced.graph)
    """
    graph():
        %x : [num_users=1] = placeholder[target=x]
        %param : [num_users=1] = get_attr[target=param]
        %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        return clamp
    """

    # Code generation - valid Python code
    print(symbolic_traced.code)
    """
    def forward(self, x):
        param = self.param
        add = x + param;  x = param = None
        linear = self.linear(add);  add = None
        clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
        return clamp
    """

The **symbolic tracer** performs "symbolic execution" of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the :func:`symbolic_trace` and :class:`Tracer`
documentation.
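
A minimal sketch of this recording behavior, tracing a free function
rather than a module (the function ``f`` is purely illustrative):

```python
import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1.0

# symbolic_trace feeds a Proxy through ``f`` and records each operation;
# no real tensor computation runs during tracing
traced = torch.fx.symbolic_trace(f)

# The recorded operations appear as Nodes in the captured Graph
ops = [node.op for node in traced.graph.nodes]
print(ops)  # ['placeholder', 'call_function', 'call_function', 'output']
```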

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.
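
For instance, the Node list of a traced module can be walked directly;
each Node carries an opcode and a target (the module ``M`` below is just
for illustration):

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x).relu()

gm = torch.fx.symbolic_trace(M())

# Each Node records an opcode (placeholder, call_module, call_method,
# output, ...) and a target identifying what is being called
for node in gm.graph.nodes:
    print(node.op, node.target)
```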

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.
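
As a small sketch, a :class:`Graph` can even be built by hand and handed
to :class:`GraphModule`, which generates a ``forward`` from it (the
graph below computes ``x + 1`` and is chosen only for demonstration):

```python
import torch
import torch.fx

# Build a Graph by hand: y = x + 1
graph = torch.fx.Graph()
x = graph.placeholder('x')
y = graph.call_function(torch.add, (x, 1))
graph.output(y)

# GraphModule generates a Python ``forward`` matching the Graph's semantics
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)                    # the generated Python source
print(gm(torch.tensor(2.0)))      # tensor(3.)
```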

Taken together, this pipeline of components (symbolic tracing ->
intermediate representation -> transforms -> Python code generation)
constitutes the Python-to-Python transformation pipeline of FX. In
addition, these components can be used separately. For example,
symbolic tracing can be used in isolation to capture a form of
the code for analysis (and not transformation) purposes. Code
generation can be used for programmatically generating models, for
example from a config file. There are many uses for FX!
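
A complete trace -> transform -> regenerate round trip can be sketched
in a few lines; the rewrite below (turning every ``add`` into a ``mul``)
is an arbitrary example transform, not a built-in pass:

```python
import operator
import torch
import torch.fx

def f(x):
    return x + x

gm = torch.fx.symbolic_trace(f)

# Transform pass over the IR: rewrite every add Node into a mul
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is operator.add:
        node.target = operator.mul

# Regenerate the Python ``forward`` from the modified Graph
gm.recompile()
print(gm(torch.tensor(3.0)))  # 3 * 3 -> tensor(9.)
```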

Several example transformations can be found at the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''

from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg, has_side_effect
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern