.. currentmodule:: torch.fx

torch.fx
=============

Overview
--------
.. automodule:: torch.fx

.. _Writing Transformations:


Writing Transformations
-----------------------

What is an FX transform? Essentially, it's a function that looks like this.

::

    import torch
    import torch.fx

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        # Step 1: Acquire a Graph representing the code in `m`

        # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
        # fx.Tracer.trace and constructing a GraphModule. We'll
        # split that out in our transform to allow the caller to
        # customize tracing behavior.
        graph : torch.fx.Graph = tracer_class().trace(m)

        # Step 2: Modify this Graph or create a new one
        graph = ...

        # Step 3: Construct a Module to return
        return torch.fx.GraphModule(m, graph)

Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph`
from it, do some modifications, and return a new
:class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX
transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another
FX transform, you can pass it to TorchScript, or you can
run it. Ensuring that the inputs and outputs of your FX transform are a
:class:`torch.nn.Module` will allow for composability.
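
For instance, because each transform both consumes and produces a
:class:`torch.nn.Module`, transforms compose by ordinary function composition.
A minimal sketch, reusing the ``transform`` defined above (``MyModel`` and
``another_transform`` are hypothetical stand-ins, not FX APIs):

::

    model = MyModel()                  # hypothetical model
    model = transform(model)           # the `transform` sketched above
    model = another_transform(model)   # hypothetical second FX transform

    # The result is still a torch.nn.Module, so it can be scripted or run.
    scripted = torch.jit.script(model)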

.. note::

    It is also possible to modify an existing :class:`GraphModule` instead of
    creating a new one, like so::

        import torch
        import torch.fx

        def transform(m : torch.nn.Module) -> torch.nn.Module:
            gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

            # Modify gm.graph
            # <...>

            # Recompile the forward() method of `gm` from its Graph
            gm.recompile()

            return gm

    Note that you MUST call :meth:`GraphModule.recompile` to bring the generated
    ``forward()`` method on the ``GraphModule`` in sync with the modified :class:`Graph`.

Given that you’ve passed in a :class:`torch.nn.Module` that has been traced into a
:class:`Graph`, there are now two primary approaches you can take to building a new
:class:`Graph`.

A Quick Primer on Graphs
^^^^^^^^^^^^^^^^^^^^^^^^

Full treatment of the semantics of graphs can be found in the :class:`Graph`
documentation, but we are going to cover the basics here. A :class:`Graph` is
a data structure that represents a method on a :class:`GraphModule`. The
information that this requires is:

- What are the inputs to the method?
- What are the operations that run inside the method?
- What is the output (i.e. return) value from the method?

All three of these concepts are represented with :class:`Node` instances.
Let's see what we mean by that with a short example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return torch.topk(torch.sum(
                self.linear(x + self.linear.weight).relu(), dim=-1), 3)

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)

    gm.graph.print_tabular()

Here we define a module ``MyModule`` for demonstration purposes, instantiate it,
symbolically trace it, then call the :meth:`Graph.print_tabular` method to print
out a table showing the nodes of this :class:`Graph`:

    +---------------+---------------+----------------------------+--------------------+-------------+
    | opcode        | name          | target                     | args               | kwargs      |
    +===============+===============+============================+====================+=============+
    | placeholder   | x             | x                          | ()                 | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | get_attr      | linear_weight | linear.weight              | ()                 | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | add_1         | <built-in function add>    | (x, linear_weight) | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_module   | linear_1      | linear                     | (add_1,)           | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_method   | relu_1        | relu                       | (linear_1,)        | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | sum_1         | <built-in method sum ...>  | (relu_1,)          | {'dim': -1} |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | topk_1        | <built-in method topk ...> | (sum_1, 3)         | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | output        | output        | output                     | (topk_1,)          | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+

We can use this information to answer the questions we posed above.

- What are the inputs to the method? In FX, method inputs are specified
  via special ``placeholder`` nodes. In this case, we have a single
  ``placeholder`` node with a ``target`` of ``x``, meaning we have
  a single (non-self) argument named ``x``.
- What are the operations within the method? The ``get_attr``,
  ``call_function``, ``call_module``, and ``call_method`` nodes
  represent the operations in the method. A full treatment of
  the semantics of all of these can be found in the :class:`Node`
  documentation.
- What is the return value of the method? The return value in a
  :class:`Graph` is specified by a special ``output`` node (see the sketch after this list).
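
To make this concrete, here is a minimal sketch (using the ``gm`` traced from
``MyModule`` above) that walks the node list and prints the same information
programmatically:

::

    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            print(f'input:  {node.target}')
        elif node.op == 'output':
            # The `output` node packs the return value(s) into node.args[0]
            print(f'return: {node.args[0]}')
        else:
            # get_attr / call_function / call_module / call_method
            print(f'op:     {node.op}, target={node.target}, args={node.args}')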

Given that we now know the basics of how code is represented in
FX, we can now explore how we would edit a :class:`Graph`.

Graph Manipulation
^^^^^^^^^^^^^^^^^^

Direct Graph Manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~

One approach to building this new :class:`Graph` is to directly manipulate your old
one. To aid in this, we can simply take the :class:`Graph` we obtain from symbolic
tracing and modify it. For example, let’s say we want to replace
:func:`torch.add` calls with :func:`torch.mul` calls.

::

    import torch
    import torch.fx

    # Sample module
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph : torch.fx.Graph = tracer_class().trace(m)
        # FX represents its Graph as an ordered list of
        # nodes, so we can iterate through them.
        for node in graph.nodes:
            # Checks if we're calling a function (i.e.
            # torch.add)
            if node.op == 'call_function':
                # The target attribute is the function
                # that call_function calls.
                if node.target == torch.add:
                    node.target = torch.mul

        graph.lint() # Does some checks to make sure the
                     # Graph is well-formed.

        return torch.fx.GraphModule(m, graph)


We can also do more involved :class:`Graph` rewrites, such as
deleting or appending nodes. To aid in these transformations,
FX has utility functions for transforming the graph that can
be found in the :class:`Graph` documentation. An
example of using these APIs to append a :func:`torch.relu` call
can be found below.

::

    # Specifies the insertion point. Any nodes added to the
    # Graph within this scope will be inserted after `node`
    with traced.graph.inserting_after(node):
        # Insert a new `call_function` node calling `torch.relu`
        new_node = traced.graph.call_function(
            torch.relu, args=(node,))

        # We want all places that used the value of `node` to
        # now use that value after the `relu` call we've added.
        # We use the `replace_all_uses_with` API to do this.
        node.replace_all_uses_with(new_node)

For simple transformations that only consist of substitutions, you can also
make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py>`__.

Subgraph Rewriting With replace_pattern()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

FX also provides another level of automation on top of direct graph manipulation.
The :func:`replace_pattern` API is essentially a "find/replace" tool for editing
:class:`Graph`\s. It allows you to specify a ``pattern`` and a ``replacement`` function
and it will trace through those functions, find instances of the group of operations
in the ``pattern`` graph, and replace those instances with copies of the ``replacement``
graph. This can help to greatly automate tedious graph manipulation code, which can
get unwieldy as the transformations get more complex.
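
As a minimal sketch of the API (the module and the rewrite rule here are
illustrative only; see the examples below for fuller usage),
:func:`replace_pattern` takes a :class:`GraphModule`, a ``pattern`` callable,
and a ``replacement`` callable, and edits the module's :class:`Graph` in place:

::

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y).relu()

    # Both the pattern to search for and the subgraph to substitute are
    # written as ordinary callables; FX traces them into Graphs.
    def pattern(a, b):
        return torch.add(a, b)

    def replacement(a, b):
        return torch.mul(a, b)

    traced = torch.fx.symbolic_trace(M())
    torch.fx.replace_pattern(traced, pattern, replacement)
    print(traced.code)  # the `add` should now appear as a `mul`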

Graph Manipulation Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~

-  `Replace one
   op <https://github.com/pytorch/examples/blob/master/fx/replace_op.py>`__
-  `Conv/Batch Norm
   fusion <https://github.com/pytorch/pytorch/blob/40cbf342d3c000712da92cfafeaca651b3e0bd3e/torch/fx/experimental/optimization.py#L50>`__
-  `replace_pattern: Basic usage <https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py>`__
-  `Quantization <https://pytorch.org/docs/main/quantization.html#prototype-fx-graph-mode-quantization>`__
-  `Invert Transformation <https://github.com/pytorch/examples/blob/master/fx/invert.py>`__

Proxy/Retracing
^^^^^^^^^^^^^^^

Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy`
machinery used in symbolic tracing. For example, let’s
imagine that we wanted to write a transformation that decomposed
PyTorch functions into smaller operations. It would transform every
``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to
perform the requisite graph rewriting to insert the comparison and
multiplication after the ``F.relu``, and then clean up the original
``F.relu``. However, we can automate this process by using :class:`Proxy`
objects to automatically record operations into the :class:`Graph`.

To use this method, we write the operations that we want inserted as regular
PyTorch code and invoke that code with :class:`Proxy` objects as arguments.
These :class:`Proxy` objects will capture the operations that are performed
on them and append them to the :class:`Graph`.

::

    import torch
    import torch.fx as fx
    import torch.nn.functional as F

    # Note that this decomposition rule can be read as regular Python
    def relu_decomposition(x):
        return (x > 0) * x

    decomposition_rules = {}
    decomposition_rules[F.relu] = relu_decomposition

    def decompose(model: torch.nn.Module,
                  tracer_class : type = fx.Tracer) -> torch.nn.Module:
        """
        Decompose `model` into smaller constituent operations.
        Currently, this only supports decomposing ReLU into its
        mathematical definition: (x > 0) * x
        """
        graph : fx.Graph = tracer_class().trace(model)
        new_graph = fx.Graph()
        env = {}
        tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
        for node in graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # By wrapping the arguments with proxies,
                # we can dispatch to the appropriate
                # decomposition rule and implicitly add it
                # to the Graph by symbolically tracing it.
                proxy_args = [
                    fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
                output_proxy = decomposition_rules[node.target](*proxy_args)

                # Operations on `Proxy` always yield new `Proxy`s, and the
                # return value of our decomposition rule is no exception.
                # We need to extract the underlying `Node` from the `Proxy`
                # to use it in subsequent iterations of this transform.
                new_node = output_proxy.node
                env[node.name] = new_node
            else:
                # Default case: we don't have a decomposition rule for this
                # node, so just copy the node over into the new graph.
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        return fx.GraphModule(model, new_graph)

In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s
also allows you to specify your rewrite rules as native Python code.
For transformations that require a large number of rewrite rules
(such as vmap or grad), this can often improve the readability and
maintainability of the rules. Note that when constructing the :class:`Proxy`\s
above we also passed in a tracer that appends to ``new_graph``. If the
operations in the graph are n-ary (e.g. ``add`` is a binary operator), sharing
one tracer this way ensures that the :class:`Proxy` calls do not create
multiple instances of a graph tracer, which can lead to unexpected runtime
errors. We recommend this way of using :class:`Proxy`, especially when the
underlying operators cannot safely be assumed to be unary.
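
As a standalone illustration of the same idea (a minimal sketch, independent of
the ``decompose`` transform above), you can build a brand-new :class:`Graph` by
running ordinary PyTorch code on :class:`Proxy`\s that share a single
``GraphAppendingTracer``:

::

    import torch
    import torch.fx as fx

    graph = fx.Graph()
    tracer = torch.fx.proxy.GraphAppendingTracer(graph)

    # Create placeholder nodes and wrap them in Proxies sharing `tracer`.
    x = fx.Proxy(graph.placeholder('x'), tracer)
    y = fx.Proxy(graph.placeholder('y'), tracer)

    # Regular PyTorch code on Proxies is recorded into `graph`.
    out = torch.relu(x + y)
    graph.output(out.node)

    # Wrap the Graph in a GraphModule so it can be run.
    gm = fx.GraphModule(torch.nn.Module(), graph)
    print(gm.code)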

A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation
can be found
`here <https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py>`__.

The Interpreter Pattern
^^^^^^^^^^^^^^^^^^^^^^^

A useful code organizational pattern in FX is to loop over all the :class:`Node`\s
in a :class:`Graph` and execute them. This can be used for several things, including
runtime analysis of values flowing through the graph or transformation of the code
via retracing with :class:`Proxy`\s. For example, suppose we want to run a
:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype
properties on the nodes as we see them at runtime. That might look like:

::

    import torch
    import torch.fx
    from torch.fx.node import Node

    from typing import Dict

    class ShapeProp:
        """
        Shape propagation. This class takes a `GraphModule`.
        Then, its `propagate` method executes the `GraphModule`
        node-by-node with the given arguments. As each operation
        executes, the ShapeProp class stores away the shape and
        element type for the output values of each operation on
        the `shape` and `dtype` attributes of the operation's
        `Node`.
        """
        def __init__(self, mod):
            self.mod = mod
            self.graph = mod.graph
            self.modules = dict(self.mod.named_modules())

        def propagate(self, *args):
            args_iter = iter(args)
            env : Dict[str, Node] = {}

            def load_arg(a):
                return torch.fx.graph.map_arg(a, lambda n: env[n.name])

            def fetch_attr(target : str):
                target_atoms = target.split('.')
                attr_itr = self.mod
                for i, atom in enumerate(target_atoms):
                    if not hasattr(attr_itr, atom):
                        raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                    attr_itr = getattr(attr_itr, atom)
                return attr_itr

            for node in self.graph.nodes:
                if node.op == 'placeholder':
                    result = next(args_iter)
                elif node.op == 'get_attr':
                    result = fetch_attr(node.target)
                elif node.op == 'call_function':
                    result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
                elif node.op == 'call_method':
                    self_obj, *args = load_arg(node.args)
                    kwargs = load_arg(node.kwargs)
                    result = getattr(self_obj, node.target)(*args, **kwargs)
                elif node.op == 'call_module':
                    result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

                # This is the only code specific to shape propagation.
                # You can delete this `if` branch and this becomes
                # a generic GraphModule interpreter.
                if isinstance(result, torch.Tensor):
                    node.shape = result.shape
                    node.dtype = result.dtype

                env[node.name] = result

            # The last node in the Graph is always the `output` node; its
            # args hold the value returned from the generated `forward()`.
            output_node = list(self.graph.nodes)[-1]
            return load_arg(output_node.args[0])

As you can see, a full interpreter for FX is not that complicated,
but it can be very useful. To ease using this pattern, we provide
the :class:`Interpreter` class, which encompasses the above logic
in a way that certain aspects of the interpreter's execution can
be overridden via method overrides.
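
As a small sketch of how :class:`Interpreter` captures this pattern (this is an
ad-hoc logging pass, not the built-in ``ShapeProp``), overriding a single
method is enough to observe every intermediate value:

::

    import torch
    import torch.fx

    class ShapeLogger(torch.fx.Interpreter):
        # `run_node` is invoked once per Node during `run()`; we observe
        # the concrete result and then pass it along unchanged.
        def run_node(self, n):
            result = super().run_node(n)
            if isinstance(result, torch.Tensor):
                print(f'{n.name}: shape={tuple(result.shape)}, dtype={result.dtype}')
            return result

    gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()))
    ShapeLogger(gm).run(torch.randn(2, 4))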

In addition to executing operations, we can also generate a new
:class:`Graph` by feeding :class:`Proxy` values through an interpreter.
Similarly, we provide the :class:`Transformer` class to encompass
this pattern. :class:`Transformer` behaves similarly to
:class:`Interpreter`, but instead of calling the ``run`` method to
get a concrete output value from the Module, you would call the
:meth:`Transformer.transform` method to return a new
:class:`GraphModule` which was subject to any transformation rules
you installed as overridden methods.
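
As a minimal sketch of this pattern (the rewrite rule here is arbitrary, chosen
only for illustration), the following :class:`Transformer` subclass replaces
every ``torch.neg`` call with ``torch.relu``:

::

    import torch
    import torch.fx

    class NegToRelu(torch.fx.Transformer):
        # `call_function` is invoked for every `call_function` node;
        # delegating to super() with a different target records that
        # call into the new Graph instead.
        def call_function(self, target, args, kwargs):
            if target == torch.neg:
                return super().call_function(torch.relu, args, kwargs)
            return super().call_function(target, args, kwargs)

    def f(x):
        return torch.neg(x) + 1

    gm = torch.fx.symbolic_trace(f)
    new_gm = NegToRelu(gm).transform()
    print(new_gm.code)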

Examples of the Interpreter Pattern
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-  `Shape
   Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py>`__
-  `Performance Profiler <https://github.com/pytorch/tutorials/pull/1319>`__


Debugging
-----------

Introduction
^^^^^^^^^^^^^^^^

Often in the course of authoring transformations, our code will not be quite right.
In this case, we may need to do some debugging. The key is to work
backwards: first, check the results of invoking the generated module to prove or
disprove correctness. Then, inspect and debug the generated code. Then, debug the
process of transformations that led to the generated code.

If you’re not familiar with debuggers, please see the auxiliary section
:ref:`Available Debuggers`.


Common Pitfalls in Transform Authoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* Nondeterministic ``set`` iteration order. In Python, the ``set`` datatype is
  unordered. Using ``set`` to contain collections of objects like ``Node``\ s,
  for example, can cause unexpected nondeterminism. An example is iterating
  over a set of ``Node``\ s to insert them into a ``Graph``. Because the
  ``set`` data type is unordered, the ordering of the operations in the output
  program will be nondeterministic and can change across program invocations.
  The recommended alternative is to use a ``dict`` data type, which is
  `insertion ordered <https://mail.python.org/pipermail/python-dev/2017-December/151283.html>`_
  as of Python 3.7 (and as of CPython 3.6). A ``dict`` can be used equivalently
  to a set by storing values to be deduplicated in the keys of the ``dict``, as sketched below.
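
  A minimal sketch of this idiom (``maybe_duplicate_nodes`` stands in for
  whatever iterable of ``Node``\ s you have collected)::

      # dict keys are unique and insertion-ordered, so this deduplicates
      # the nodes while keeping the order in which they were first seen.
      deduped_nodes = list(dict.fromkeys(maybe_duplicate_nodes))

      # By contrast, `set(maybe_duplicate_nodes)` would lose the ordering.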

Checking Correctness of Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because the output of most deep learning modules consists of floating
point :class:`torch.Tensor` instances, checking for equivalence between
the results of two :class:`torch.nn.Module`\ s is not as straightforward
as doing a simple equality check. To motivate this, let's use an
example:

::

    import torch
    import torch.fx
    import torchvision.models as models

    def transform(m : torch.nn.Module) -> torch.nn.Module:
        gm = torch.fx.symbolic_trace(m)

        # Imagine we're doing some transforms here
        # <...>

        gm.recompile()

        return gm

    resnet18 = models.resnet18()
    transformed_resnet18 = transform(resnet18)

    input_image = torch.randn(5, 3, 224, 224)

    assert resnet18(input_image) == transformed_resnet18(input_image)
    """
    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
    """

Here, we've tried to check equality of the values of two deep learning
models with the ``==`` equality operator. However, this is not well-defined,
both because that operator returns a tensor rather than a bool, and because
comparisons of floating point values should use a margin of error (or epsilon)
to account for rounding error and the non-associativity of floating point
operations (see
`here <https://floating-point-gui.de/errors/comparison/>`__ for more
details). We can use :func:`torch.allclose` instead, which will give
us an approximate comparison taking into account a relative and
absolute tolerance threshold:

::

    assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

This is the first tool in our toolbox to check if transformed modules are
behaving as we expect compared to a reference implementation.
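
Building on this, a small sketch of a reusable checker (the helper name, the
use of :func:`torch.testing.assert_close`, and the trial count are our own
choices here, not an FX API):

::

    import torch

    def check_equivalence(ref: torch.nn.Module, test: torch.nn.Module,
                          *example_inputs: torch.Tensor, n_trials: int = 3):
        # Compare the two modules on several freshly sampled inputs with
        # the same shape/dtype as the examples; assert_close raises if any
        # pair of outputs diverges beyond its default tolerances.
        for _ in range(n_trials):
            args = tuple(torch.randn_like(t) for t in example_inputs)
            torch.testing.assert_close(ref(*args), test(*args))

    check_equivalence(resnet18, transformed_resnet18, input_image)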

Debugging the Generated Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using
traditional debugging techniques like ``print`` statements or ``pdb`` is
not as straightforward. Luckily, we have several techniques we can use
for debugging the generated code.

Use ``pdb``
~~~~~~~~~~~~~
Invoke ``pdb`` to step into the running program. Although the code that
represents the :class:`Graph` is not in any source file, we can still step
into it manually using ``pdb`` when the forward pass is invoked.

::

    import torch
    import torch.fx
    import torchvision.models as models

    def my_pass(inp: torch.nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph = tracer_class().trace(inp)
        # Transformation logic here
        # <...>

        # Return new Module
        return torch.fx.GraphModule(inp, graph)

    my_module = models.resnet18()
    my_module_transformed = my_pass(my_module)

    input_value = torch.randn(5, 3, 224, 224)

    # When this line is executed at runtime, we will be dropped into an
    # interactive `pdb` prompt. We can use the `step` or `s` command to
    # step into the execution of the next line
    import pdb; pdb.set_trace()

    my_module_transformed(input_value)

.. _Print the Generated Code:

Print the Generated Code
~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you’d like to run the same code multiple times, then it can be
a bit tedious to step to the right code with ``pdb``. In that case, one
approach is to simply copy-paste the generated ``forward`` pass into
your code and examine it from there.

::

    # Assume that `traced` is a GraphModule that has undergone some
    # number of transforms

    # Copy this code for later
    print(traced)
    # Print the code generated from symbolic tracing. This outputs:
    """
    def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1
    """

    # Subclass the original Module
    class SubclassM(M):
        def __init__(self):
            super().__init__()

        # Paste the generated `forward` function (the one we printed and
        # copied above) here
        def forward(self, y):
            x = self.x
            add_1 = x + y;  x = y = None
            return add_1

    # Create an instance of the original, untraced Module. Then, create an
    # instance of the Module with the copied `forward` function. We can
    # now compare the output of both the original and the traced version.
    pre_trace = M()
    post_trace = SubclassM()

Use the ``to_folder`` Method From ``GraphModule``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:meth:`GraphModule.to_folder` is a method in ``GraphModule`` that allows
you to dump out the generated FX code to a folder. Although copying the
forward pass into the code often suffices as in :ref:`Print the Generated Code`,
it may be easier to examine modules and parameters using ``to_folder``.

::

    m = torch.fx.symbolic_trace(M())
    m.to_folder("foo", "Bar")
    from foo import Bar
    y = Bar()

After running the above example, we can then look at the code within
``foo/module.py`` and modify it as desired (e.g. adding ``print``
statements or using ``pdb``) to debug the generated code.

Debugging the Transformation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Now that we've identified that a transformation is creating incorrect
code, it's time to debug the transformation itself. First, we'll check
the :ref:`Limitations of Symbolic Tracing` section in the documentation.
Once we verify that tracing is working as expected, the goal
becomes figuring out what went wrong during our ``GraphModule``
transformation. There may be a quick answer in
:ref:`Writing Transformations`, but, if not, there are several ways to
examine our traced module:

::

    import torch
    import torch.fx

    # Sample Module
    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # Create an instance of `M`
    m = M()

    # Symbolically trace an instance of `M` (returns a GraphModule). In
    # this example, we'll only be discussing how to inspect a
    # GraphModule, so we aren't showing any sample transforms for the
    # sake of brevity.
    traced = torch.fx.symbolic_trace(m)

    # Print the code produced by tracing the module.
    print(traced)
    # The generated `forward` function is:
    """
    def forward(self, x, y):
        add = x + y;  x = y = None
        return add
    """

    # Print the internal Graph.
    print(traced.graph)
    # This print-out returns:
    """
    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
        return add
    """

    # Print a tabular representation of the internal Graph.
    traced.graph.print_tabular()
    # This gives us:
    """
    opcode         name    target                   args    kwargs
    -------------  ------  -----------------------  ------  --------
    placeholder    x       x                        ()      {}
    placeholder    y       y                        ()      {}
    call_function  add     <built-in function add>  (x, y)  {}
    output         output  output                   (add,)  {}
    """

Using the utility functions above, we can compare our traced Module
before and after we've applied our transformations. Sometimes, a
simple visual comparison is enough to track down a bug. If it's still
not clear what's going wrong, a debugger like ``pdb`` can be a good
next step.

Building on the example above, consider the following code:

::

    # Sample user-defined function
    def transform_graph(module: torch.nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        # Get the Graph from our traced Module
        g = tracer_class().trace(module)

        """
        Transformations on `g` go here
        """

        return torch.fx.GraphModule(module, g)

    # Transform the Graph
    transformed = transform_graph(traced)

    # Print the new code after our transforms. Check to see if it was
    # what we expected
    print(transformed)

Using the above example, let’s say that the call to ``print(transformed)``
showed us that there was an error in our transforms. We want to find
out what went wrong using a debugger. We start a ``pdb`` session. We can see
what’s happening during the transform by breaking on
``transform_graph(traced)``, then pressing ``s`` to “step into” the call
to ``transform_graph(traced)``.

We may also have good luck by editing the ``print_tabular`` method to print
different attributes of the Nodes in the Graph. (For example, we might
want to see the Node’s ``all_input_nodes`` and ``users``.)

.. _Available Debuggers:

Available Debuggers
^^^^^^^^^^^^^^^^^^^^^^

The most common Python debugger is
`pdb <https://docs.python.org/3/library/pdb.html>`__. You can start
your program in “debug mode” with ``pdb`` by typing
``python -m pdb FILENAME.py`` into the command line, where ``FILENAME``
is the name of the file you want to debug. After that, you can use the
``pdb`` `debugger commands
<https://docs.python.org/3/library/pdb.html#debugger-commands>`__
to move through your running program stepwise. It’s common to set a
breakpoint (``b LINE-NUMBER``) when you start ``pdb``, then call ``c`` to
run the program until that point. This prevents you from having to step
through each line of execution (using ``s`` or ``n``) to get to the part
of the code you want to examine. Alternatively, you can write
``import pdb; pdb.set_trace()`` before the line you want to break at.
If you add ``pdb.set_trace()``, your program will automatically start
in debug mode when you run it. (In other words, you can just type
``python FILENAME.py`` into the command line instead of
``python -m pdb FILENAME.py``.) Once you're running your file in
debug mode, you can step through the code and examine your program's
internal state using certain commands. There are many excellent
tutorials on ``pdb`` online, including RealPython’s
`“Python Debugging With Pdb” <https://realpython.com/python-debugging-pdb/>`__.

IDEs like PyCharm or VSCode usually have a debugger built in. In your
IDE, you can choose to either a) use ``pdb`` by pulling up a terminal
window in your IDE (e.g. View → Terminal in VSCode), or b) use the
built-in debugger (usually a graphical wrapper around ``pdb``).

.. _Limitations of Symbolic Tracing:

Limitations of Symbolic Tracing
-------------------------------

FX uses a system of **symbolic tracing** (a.k.a. `symbolic
execution <https://en.wikipedia.org/wiki/Symbolic_execution>`__)
to capture the semantics of programs in a transformable/analyzable form.
The system is **tracing** in that it executes the program (really a
:class:`torch.nn.Module` or function) to record operations. It is
**symbolic** in that the data flowing through the program during this
execution is not real data, but rather symbols (:class:`Proxy` in FX parlance).

Although symbolic tracing works for most neural net code, it has some
limitations.

Dynamic Control Flow
^^^^^^^^^^^^^^^^^^^^

The main limitation of symbolic tracing is that it does not currently support
*dynamic control flow*. That is, loops or ``if`` statements where the
condition may depend on the input values of the program.

For example, let’s examine the following program:

::

    import torch
    import torch.fx

    def func_to_trace(x):
        if x.sum() > 0:
            return torch.relu(x)
        else:
            return torch.neg(x)

    traced = torch.fx.symbolic_trace(func_to_trace)
    """
      <...>
      File "dyn.py", line 6, in func_to_trace
        if x.sum() > 0:
      File "pytorch/torch/fx/proxy.py", line 155, in __bool__
        return self.tracer.to_bool(self)
      File "pytorch/torch/fx/proxy.py", line 85, in to_bool
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
    torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
    """

The condition to the ``if`` statement relies on the value of ``x.sum()``,
which relies on the value of ``x``, a function input. Since
``x`` can change (i.e. if you pass a new input tensor to the traced
function), this is *dynamic control flow*. The traceback walks back up
through your code to show you where this situation happens.

Static Control Flow
~~~~~~~~~~~~~~~~~~~

On the other hand, so-called *static control flow* is supported. Static
control flow is loops or ``if`` statements whose condition cannot change
across invocations. Typically, in PyTorch programs, this control flow
arises for code making decisions about a model’s architecture based on
hyper-parameters. As a concrete example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self, do_activation : bool = False):
            super().__init__()
            self.do_activation = do_activation
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            x = self.linear(x)
            # This if-statement is so-called static control flow.
            # Its condition does not depend on any input values
            if self.do_activation:
                x = torch.relu(x)
            return x

    without_activation = MyModule(do_activation=False)
    with_activation = MyModule(do_activation=True)

    traced_without_activation = torch.fx.symbolic_trace(without_activation)
    print(traced_without_activation.code)
    """
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        return linear_1
    """

    traced_with_activation = torch.fx.symbolic_trace(with_activation)
    print(traced_with_activation.code)
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        relu_1 = torch.relu(linear_1);  linear_1 = None
        return relu_1
    """

The if-statement ``if self.do_activation`` does not depend on any
function inputs, so it is static. ``do_activation`` can be considered
to be a hyper-parameter, and the traces of different instances of
``MyModule`` with different values for that parameter have different
code. This is a valid pattern that is supported by symbolic tracing.

Many instances of dynamic control flow are semantically static control
flow. These instances can be made to support symbolic tracing by
removing the data dependencies on input values, for example by moving
values to ``Module`` attributes or by binding concrete values to arguments
during symbolic tracing:

::

    def f(x, flag):
        if flag: return x
        else: return x*2

    torch.fx.symbolic_trace(f) # Fails!

    torch.fx.symbolic_trace(f, concrete_args={'flag': True})  # Succeeds; `flag` is baked in as True

In the case of truly dynamic control flow, the sections of the program
that contain this code can be traced as calls to a method (see
:ref:`Customizing Tracing`) or a function (see
:func:`wrap`) rather than traced through, as sketched below.
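
For example, here is a minimal sketch of the function route (the helper name is
made up for illustration): decorating the data-dependent piece with :func:`wrap`
records it as a single ``call_function`` node instead of tracing into it.

::

    import torch
    import torch.fx

    # A helper containing data-dependent control flow. Because it is
    # wrapped, it is recorded as an opaque call in the trace rather than
    # being traced through (which would raise a TraceError).
    @torch.fx.wrap
    def relu_if_negative_sum(x):
        if x.sum() < 0:
            return torch.relu(x)
        return x

    def f(x):
        return relu_if_negative_sum(x) * 2

    traced = torch.fx.symbolic_trace(f)
    print(traced.code)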

Non-\ ``torch`` Functions
^^^^^^^^^^^^^^^^^^^^^^^^^

FX uses ``__torch_function__`` as the mechanism by which it intercepts
calls (see the `technical
overview <https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#technical-details>`__
for more information about this). Some functions, such as built-in Python
functions or those in the ``math`` module, are not covered by
``__torch_function__``, but we would still like to capture them in
symbolic tracing. For example:

::

    import torch
    import torch.fx
    from math import sqrt

    def normalize(x):
        """
        Normalize `x` by the size of the batch dimension
        """
        return x / sqrt(len(x))

    # It's valid Python code
    normalize(torch.rand(3, 4))

    traced = torch.fx.symbolic_trace(normalize)
    """
      <...>
      File "sqrt.py", line 9, in normalize
        return x / sqrt(len(x))
      File "pytorch/torch/fx/proxy.py", line 161, in __len__
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
    RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
    """

The error tells us that the built-in function ``len`` is not supported.
We can make it so that functions like this are recorded in the trace as
direct calls using the :func:`wrap` API:

::

    torch.fx.wrap('len')
    torch.fx.wrap('sqrt')

    traced = torch.fx.symbolic_trace(normalize)

    print(traced.code)
    """
    import math
    def forward(self, x):
        len_1 = len(x)
        sqrt_1 = math.sqrt(len_1);  len_1 = None
        truediv = x / sqrt_1;  x = sqrt_1 = None
        return truediv
    """

.. _Customizing Tracing:

Customizing Tracing with the ``Tracer`` class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:class:`Tracer` is the class that underlies the
implementation of ``symbolic_trace``. The behavior of tracing can be
customized by subclassing :class:`Tracer`, like so:

::

    import torch
    import torch.fx

    class MyCustomTracer(torch.fx.Tracer):
        # Inside here you can override various methods
        # to customize tracing. See the `Tracer` API
        # reference
        pass


    # Let's use this custom tracer to trace through this module
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + torch.ones(3, 4)

    mod = MyModule()

    traced_graph = MyCustomTracer().trace(mod)
    # trace() returns a Graph. Let's wrap it up in a
    # GraphModule to make it runnable
    traced = torch.fx.GraphModule(mod, traced_graph)

Leaf Modules
~~~~~~~~~~~~

Leaf Modules are the modules that appear as calls in the symbolic trace
rather than being traced through. The default set of leaf modules is the
set of standard ``torch.nn`` module instances. For example:

::

    class MySpecialSubmodule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 4)
            self.submod = MySpecialSubmodule()

        def forward(self, x):
            return self.submod(self.linear(x))

    traced = torch.fx.symbolic_trace(MyModule())
    print(traced.code)
    # `linear` is preserved as a call, yet `submod` is traced through.
    # This is because the default set of "Leaf Modules" includes all
    # standard `torch.nn` modules.
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        neg_1 = torch.neg(linear_1);  linear_1 = None
        return neg_1
    """

The set of leaf modules can be customized by overriding
:meth:`Tracer.is_leaf_module`, as sketched below.
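
A minimal sketch, reusing ``MySpecialSubmodule`` and ``MyModule`` from the
example above (the tracer name is arbitrary):

::

    class LeafTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, module_qualified_name):
            # Treat MySpecialSubmodule instances as leaves so they appear
            # as `call_module` nodes instead of being traced through.
            if isinstance(m, MySpecialSubmodule):
                return True
            return super().is_leaf_module(m, module_qualified_name)

    mod = MyModule()
    graph = LeafTracer().trace(mod)
    traced = torch.fx.GraphModule(mod, graph)
    print(traced.code)
    # `self.submod(...)` should now be preserved as a call in the
    # generated code rather than decomposed into `torch.neg`.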

Miscellanea
^^^^^^^^^^^

-  Tensor constructors (e.g. ``torch.zeros``, ``torch.ones``,
   ``torch.rand``, ``torch.randn``, ``torch.sparse_coo_tensor``)
   are currently not traceable.

   -  The deterministic constructors (``zeros``, ``ones``) can be used
      and the value they produce will be embedded in the trace as a
      constant. This is only problematic if the arguments to these
      constructors refer to dynamic input sizes. In this case,
      ``ones_like`` or ``zeros_like`` may be a viable substitute.
   -  Nondeterministic constructors (``rand``, ``randn``) will have a
      single random value embedded in the trace. This is likely not the
      intended behavior. One workaround is to wrap ``torch.randn`` in a ``torch.fx.wrap`` function and call that instead.

    ::

        @torch.fx.wrap
        def torch_randn(x, shape):
            return torch.randn(shape)

        def f(x):
            return x + torch_randn(x, 5)
        torch.fx.symbolic_trace(f)

   -  This behavior may be fixed in a future release.

-  Type annotations

   -  Python 3-style type annotations (e.g.
      ``func(x : torch.Tensor, y : int) -> torch.Tensor``) are supported
      and will be preserved by symbolic tracing.
   -  Python 2-style comment type annotations
      ``# type: (torch.Tensor, int) -> torch.Tensor`` are not currently
      supported.
   -  Annotations on local names within a function are not currently
      supported.


-  Gotcha around ``training`` flag and submodules

   -  When using functionals like ``torch.nn.functional.dropout``, it will be common for the training argument to be passed in as ``self.training``. During FX tracing, this will likely be baked in as a constant value.

    ::

        import torch
        import torch.fx

        class DropoutRepro(torch.nn.Module):
          def forward(self, x):
            return torch.nn.functional.dropout(x, training=self.training)


        traced = torch.fx.symbolic_trace(DropoutRepro())
        print(traced.code)
        """
        def forward(self, x):
          dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
          return dropout
        """

        traced.eval()

        x = torch.randn(5, 3)
        torch.testing.assert_close(traced(x), x)
        """
        AssertionError: Tensor-likes are not close!

        Mismatched elements: 15 / 15 (100.0%)
        Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
        Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
        """

   -  However, when the standard ``nn.Dropout()`` submodule is used, the training flag is encapsulated and, because the ``nn.Module`` object model is preserved, it can be changed.

    ::

        class DropoutRepro2(torch.nn.Module):
          def __init__(self):
            super().__init__()
            self.drop = torch.nn.Dropout()

          def forward(self, x):
            return self.drop(x)

        traced = torch.fx.symbolic_trace(DropoutRepro2())
        print(traced.code)
        """
        def forward(self, x):
          drop = self.drop(x);  x = None
          return drop
        """

        traced.eval()

        x = torch.randn(5, 3)
        torch.testing.assert_close(traced(x), x)

   -  Because of this difference, consider marking modules that interact with the ``training`` flag dynamically as leaf modules.


API Reference
-------------

.. autofunction:: torch.fx.symbolic_trace

.. autofunction:: torch.fx.wrap

.. autoclass:: torch.fx.GraphModule
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Graph
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Node
  :members:

.. autoclass:: torch.fx.Tracer
  :members:
  :inherited-members:

.. autoclass:: torch.fx.Proxy

.. autoclass:: torch.fx.Interpreter
  :members:

.. autoclass:: torch.fx.Transformer
  :members:

.. autofunction:: torch.fx.replace_pattern


.. The experimental and passes submodules are missing docs.
.. Adding it here for coverage but this doesn't add anything to the
.. rendered doc.
.. py:module:: torch.fx.passes
.. py:module:: torch.fx.passes.infra
.. py:module:: torch.fx.passes.backends
.. py:module:: torch.fx.passes.utils
.. py:module:: torch.fx.passes.tests
.. py:module:: torch.fx.experimental
.. py:module:: torch.fx.experimental.unification
.. py:module:: torch.fx.experimental.unification.multipledispatch
.. py:module:: torch.fx.experimental.migrate_gradual_types
.. py:module:: torch.fx.passes.dialect
.. py:module:: torch.fx.passes.dialect.common
.. py:module:: torch.fx.annotate
.. py:module:: torch.fx.config
.. py:module:: torch.fx.experimental.accelerator_partitioner
.. py:module:: torch.fx.experimental.const_fold
.. py:module:: torch.fx.experimental.debug
.. py:module:: torch.fx.experimental.graph_gradual_typechecker
.. py:module:: torch.fx.experimental.merge_matmul
.. py:module:: torch.fx.experimental.meta_tracer
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_generator
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_transformation
.. py:module:: torch.fx.experimental.migrate_gradual_types.operation
.. py:module:: torch.fx.experimental.migrate_gradual_types.transform_to_z3
.. py:module:: torch.fx.experimental.migrate_gradual_types.util
.. py:module:: torch.fx.experimental.migrate_gradual_types.z3_types
.. py:module:: torch.fx.experimental.normalize
.. py:module:: torch.fx.experimental.optimization
.. py:module:: torch.fx.experimental.partitioner_utils
.. py:module:: torch.fx.experimental.recording
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
.. py:module:: torch.fx.experimental.unification.more
.. py:module:: torch.fx.experimental.unification.multipledispatch.conflict
.. py:module:: torch.fx.experimental.unification.multipledispatch.core
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements
.. py:module:: torch.fx.experimental.validator
.. py:module:: torch.fx.graph
.. py:module:: torch.fx.graph_module
.. py:module:: torch.fx.immutable_collections
.. py:module:: torch.fx.interpreter
.. py:module:: torch.fx.node
.. py:module:: torch.fx.operator_schemas
.. py:module:: torch.fx.passes.annotate_getitem_nodes
.. py:module:: torch.fx.passes.backends.cudagraphs
.. py:module:: torch.fx.passes.dialect.common.cse_pass
.. py:module:: torch.fx.passes.fake_tensor_prop
.. py:module:: torch.fx.passes.graph_drawer
.. py:module:: torch.fx.passes.graph_manipulation
.. py:module:: torch.fx.passes.graph_transform_observer
.. py:module:: torch.fx.passes.infra.partitioner
.. py:module:: torch.fx.passes.infra.pass_base
.. py:module:: torch.fx.passes.infra.pass_manager
.. py:module:: torch.fx.passes.net_min_base
.. py:module:: torch.fx.passes.operator_support
.. py:module:: torch.fx.passes.param_fetch
.. py:module:: torch.fx.passes.pass_manager
.. py:module:: torch.fx.passes.reinplace
.. py:module:: torch.fx.passes.runtime_assert
.. py:module:: torch.fx.passes.shape_prop
.. py:module:: torch.fx.passes.split_module
.. py:module:: torch.fx.passes.split_utils
.. py:module:: torch.fx.passes.splitter_base
.. py:module:: torch.fx.passes.tests.test_pass_manager
.. py:module:: torch.fx.passes.tools_common
.. py:module:: torch.fx.passes.utils.common
.. py:module:: torch.fx.passes.utils.fuser_utils
.. py:module:: torch.fx.passes.utils.matcher_utils
.. py:module:: torch.fx.passes.utils.matcher_with_name_node_map_utils
.. py:module:: torch.fx.passes.utils.source_matcher_utils
.. py:module:: torch.fx.proxy
.. py:module:: torch.fx.subgraph_rewriter
.. py:module:: torch.fx.tensor_type
.. py:module:: torch.fx.traceback