.. currentmodule:: torch.fx

torch.fx
=============

Overview
--------
.. automodule:: torch.fx

.. _Writing Transformations:


Writing Transformations
-----------------------

What is an FX transform? Essentially, it's a function that looks like this.

::

    import torch
    import torch.fx

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        # Step 1: Acquire a Graph representing the code in `m`

        # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
        # fx.Tracer.trace and constructing a GraphModule. We'll
        # split that out in our transform to allow the caller to
        # customize tracing behavior.
        graph : torch.fx.Graph = tracer_class().trace(m)

        # Step 2: Modify this Graph or create a new one
        graph = ...

        # Step 3: Construct a Module to return
        return torch.fx.GraphModule(m, graph)

Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph`
from it, make some modifications, and return a new
:class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX
transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another
FX transform, you can pass it to TorchScript, or you can
run it. Ensuring that the inputs and outputs of your FX transform are a
:class:`torch.nn.Module` will allow for composability.

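Because each transform both consumes and produces a :class:`torch.nn.Module`,
transforms chain naturally. As a minimal sketch (the ``MyBlock`` module and the
do-nothing ``identity_transform`` below are made up for illustration)::

    import torch
    import torch.fx

    class MyBlock(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1.0

    def identity_transform(m: torch.nn.Module) -> torch.nn.Module:
        # A do-nothing FX transform: trace `m` and repackage it unchanged.
        gm = torch.fx.symbolic_trace(m)
        return torch.fx.GraphModule(gm, gm.graph)

    # Transforms compose because each one takes and returns a torch.nn.Module...
    composed = identity_transform(identity_transform(MyBlock()))
    # ...and the result is still an ordinary module, so it can be handed off
    # to TorchScript or simply run.
    scripted = torch.jit.script(composed)
    print(scripted(torch.randn(4)))
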
.. note::

    It is also possible to modify an existing :class:`GraphModule` instead of
    creating a new one, like so::

        import torch
        import torch.fx

        def transform(m : torch.nn.Module) -> torch.nn.Module:
            gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

            # Modify gm.graph
            # <...>

            # Recompile the forward() method of `gm` from its Graph
            gm.recompile()

            return gm

    Note that you MUST call :meth:`GraphModule.recompile` to bring the generated
    ``forward()`` method on the ``GraphModule`` in sync with the modified :class:`Graph`.

Given that you've passed in a :class:`torch.nn.Module` that has been traced into a
:class:`Graph`, there are now two primary approaches you can take to building a new
:class:`Graph`.

A Quick Primer on Graphs
^^^^^^^^^^^^^^^^^^^^^^^^

Full treatment of the semantics of graphs can be found in the :class:`Graph`
documentation, but we are going to cover the basics here. A :class:`Graph` is
a data structure that represents a method on a :class:`GraphModule`. The
information that this requires is:

- What are the inputs to the method?
- What are the operations that run inside the method?
- What is the output (i.e. return) value from the method?

All three of these concepts are represented with :class:`Node` instances.
Let's see what we mean by that with a short example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return torch.topk(torch.sum(
                self.linear(x + self.linear.weight).relu(), dim=-1), 3)

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)

    gm.graph.print_tabular()

Here we define a module ``MyModule`` for demonstration purposes, instantiate it,
symbolically trace it, then call the :meth:`Graph.print_tabular` method to print
out a table showing the nodes of this :class:`Graph`:

    +---------------+---------------+----------------------------+--------------------+-------------+
    | opcode        | name          | target                     | args               | kwargs      |
    +===============+===============+============================+====================+=============+
    | placeholder   | x             | x                          | ()                 | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | get_attr      | linear_weight | linear.weight              | ()                 | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | add_1         | <built-in function add>    | (x, linear_weight) | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_module   | linear_1      | linear                     | (add_1,)           | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_method   | relu_1        | relu                       | (linear_1,)        | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | sum_1         | <built-in method sum ...>  | (relu_1,)          | {'dim': -1} |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | call_function | topk_1        | <built-in method topk ...> | (sum_1, 3)         | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+
    | output        | output        | output                     | (topk_1,)          | {}          |
    +---------------+---------------+----------------------------+--------------------+-------------+

We can use this information to answer the questions we posed above.

- What are the inputs to the method? In FX, method inputs are specified
  via special ``placeholder`` nodes. In this case, we have a single
  ``placeholder`` node with a ``target`` of ``x``, meaning we have
  a single (non-self) argument named ``x``.
- What are the operations within the method? The ``get_attr``,
  ``call_function``, ``call_module``, and ``call_method`` nodes
  represent the operations in the method. A full treatment of
  the semantics of all of these can be found in the :class:`Node`
  documentation.
- What is the return value of the method? The return value in a
  :class:`Graph` is specified by a special ``output`` node.

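The same information is also available programmatically. As a quick sketch
(reusing ``gm`` from the example above), we can loop over the nodes of the
:class:`Graph` and read the fields shown in the table::

    for node in gm.graph.nodes:
        # Each Node carries the opcode, name, target, args and kwargs
        # columns from the tabular printout above.
        print(node.op, node.name, node.target, node.args, node.kwargs)
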
Now that we know the basics of how code is represented in
FX, we can explore how to edit a :class:`Graph`.

Graph Manipulation
^^^^^^^^^^^^^^^^^^

Direct Graph Manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~

One approach to building a new :class:`Graph` is to directly manipulate your old
one. To aid in this, we can simply take the :class:`Graph` we obtain from symbolic
tracing and modify it. For example, let's say we want to replace
:func:`torch.add` calls with :func:`torch.mul` calls.

::

    import torch
    import torch.fx

    # Sample module
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph : torch.fx.Graph = tracer_class().trace(m)
        # FX represents its Graph as an ordered list of
        # nodes, so we can iterate through them.
        for node in graph.nodes:
            # Checks if we're calling a function (e.g. torch.add)
            if node.op == 'call_function':
                # The target attribute is the function
                # that call_function calls.
                if node.target == torch.add:
                    node.target = torch.mul

        graph.lint() # Does some checks to make sure the
                     # Graph is well-formed.

        return torch.fx.GraphModule(m, graph)


We can also do more involved :class:`Graph` rewrites, such as
deleting or appending nodes. To aid in these transformations,
FX has utility functions for transforming the graph that can
be found in the :class:`Graph` documentation. An
example of using these APIs to append a :func:`torch.relu` call
can be found below.

::

    # Specifies the insertion point. Any nodes added to the
    # Graph within this scope will be inserted after `node`
    with traced.graph.inserting_after(node):
        # Insert a new `call_function` node calling `torch.relu`
        new_node = traced.graph.call_function(
            torch.relu, args=(node,))

        # We want all places that used the value of `node` to
        # now use that value after the `relu` call we've added.
        # We use the `replace_all_uses_with` API to do this.
        node.replace_all_uses_with(new_node)

For simple transformations that only consist of substitutions, you can also
make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py>`__.

Subgraph Rewriting With replace_pattern()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

FX also provides another level of automation on top of direct graph manipulation.
The :func:`replace_pattern` API is essentially a "find/replace" tool for editing
:class:`Graph`\s. It allows you to specify a ``pattern`` and a ``replacement`` function.
:func:`replace_pattern` will trace through those functions, find instances of the group
of operations in the ``pattern`` graph, and replace those instances with copies of the
``replacement`` graph. This can greatly automate tedious graph manipulation code, which
can get unwieldy as the transformations get more complex.

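As a hedged sketch of what this looks like, here is the add-to-mul rewrite from
above expressed with :func:`replace_pattern` (the module ``M`` and the ``pattern``
and ``replacement`` functions are made up for illustration)::

    import torch
    from torch.fx import symbolic_trace, replace_pattern

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    def pattern(a, b):
        return torch.add(a, b)

    def replacement(a, b):
        return torch.mul(a, b)

    traced = symbolic_trace(M())
    # Find occurrences of `pattern` in `traced` and swap in `replacement`
    replace_pattern(traced, pattern, replacement)
    print(traced.code)  # the add now appears as a mul
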
Graph Manipulation Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~

- `Replace one
  op <https://github.com/pytorch/examples/blob/master/fx/replace_op.py>`__
- `Conv/Batch Norm
  fusion <https://github.com/pytorch/pytorch/blob/40cbf342d3c000712da92cfafeaca651b3e0bd3e/torch/fx/experimental/optimization.py#L50>`__
- `replace_pattern: Basic usage <https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py>`__
- `Quantization <https://pytorch.org/docs/main/quantization.html#prototype-fx-graph-mode-quantization>`__
- `Invert Transformation <https://github.com/pytorch/examples/blob/master/fx/invert.py>`__

Proxy/Retracing
^^^^^^^^^^^^^^^

Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy`
machinery used in symbolic tracing. For example, let's
imagine that we wanted to write a transformation that decomposed
PyTorch functions into smaller operations. It would transform every
``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to
perform the requisite graph rewriting to insert the comparison and
multiplication after the ``F.relu``, and then clean up the original
``F.relu``. However, we can automate this process by using :class:`Proxy`
objects to automatically record operations into the :class:`Graph`.

To use this method, we write the operations that we want inserted as regular
PyTorch code and invoke that code with :class:`Proxy` objects as arguments.
These :class:`Proxy` objects will capture the operations that are performed
on them and append them to the :class:`Graph`.

::

    import torch
    import torch.fx as fx
    import torch.nn.functional as F

    # Note that this decomposition rule can be read as regular Python
    def relu_decomposition(x):
        return (x > 0) * x

    decomposition_rules = {}
    decomposition_rules[F.relu] = relu_decomposition

    def decompose(model: torch.nn.Module,
                  tracer_class : type = fx.Tracer) -> torch.nn.Module:
        """
        Decompose `model` into smaller constituent operations.
        Currently, this only supports decomposing ReLU into its
        mathematical definition: (x > 0) * x
        """
        graph : fx.Graph = tracer_class().trace(model)
        new_graph = fx.Graph()
        env = {}
        tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
        for node in graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # By wrapping the arguments with proxies,
                # we can dispatch to the appropriate
                # decomposition rule and implicitly add it
                # to the Graph by symbolically tracing it.
                proxy_args = [
                    fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x
                    for x in node.args]
                output_proxy = decomposition_rules[node.target](*proxy_args)

                # Operations on `Proxy` always yield new `Proxy`s, and the
                # return value of our decomposition rule is no exception.
                # We need to extract the underlying `Node` from the `Proxy`
                # to use it in subsequent iterations of this transform.
                new_node = output_proxy.node
                env[node.name] = new_node
            else:
                # Default case: we don't have a decomposition rule for this
                # node, so just copy the node over into the new graph.
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        return fx.GraphModule(model, new_graph)

In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s
also allows you to specify your rewrite rules as native Python code.
For transformations that require a large amount of rewrite rules
(such as vmap or grad), this can often improve readability and
maintainability of the rules. Note that when constructing each :class:`Proxy`
we also passed a tracer pointing at the underlying ``new_graph``. This is done so
that, if the operations in the graph are n-ary (e.g. ``add`` is a binary operator),
the calls to :class:`Proxy` do not create multiple instances of a graph
tracer, which can lead to unexpected runtime errors. We recommend this method
of using :class:`Proxy` especially when the underlying operators cannot be
safely assumed to be unary.

A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation
can be found
`here <https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py>`__.

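To see the decomposition in action, here is a small usage sketch. It assumes the
definitions above (``decompose``, ``decomposition_rules`` and the imports) are in
scope; ``ReluModule`` is made up for illustration::

    class ReluModule(torch.nn.Module):
        def forward(self, x):
            return F.relu(x + 1)

    decomposed = decompose(ReluModule())
    print(decomposed.code)  # the relu call now appears as (x > 0) * x

    inp = torch.randn(4)
    torch.testing.assert_close(decomposed(inp), F.relu(inp + 1))
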
The Interpreter Pattern
^^^^^^^^^^^^^^^^^^^^^^^

A useful code organizational pattern in FX is to loop over all the :class:`Node`\s
in a :class:`Graph` and execute them. This can be used for several things including
runtime analysis of values flowing through the graph or transformation of the code
via retracing with :class:`Proxy`\s. For example, suppose we want to run a
:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype
properties on the nodes as we see them at runtime. That might look like:

::

    import torch
    import torch.fx

    from typing import Any, Dict

    class ShapeProp:
        """
        Shape propagation. This class takes a `GraphModule`.
        Then, its `propagate` method executes the `GraphModule`
        node-by-node with the given arguments. As each operation
        executes, the ShapeProp class stores away the shape and
        element type for the output values of each operation on
        the `shape` and `dtype` attributes of the operation's
        `Node`.
        """
        def __init__(self, mod):
            self.mod = mod
            self.graph = mod.graph
            self.modules = dict(self.mod.named_modules())

        def propagate(self, *args):
            args_iter = iter(args)
            env : Dict[str, Any] = {}

            def load_arg(a):
                return torch.fx.graph.map_arg(a, lambda n: env[n.name])

            def fetch_attr(target : str):
                target_atoms = target.split('.')
                attr_itr = self.mod
                for i, atom in enumerate(target_atoms):
                    if not hasattr(attr_itr, atom):
                        raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                    attr_itr = getattr(attr_itr, atom)
                return attr_itr

            for node in self.graph.nodes:
                if node.op == 'placeholder':
                    result = next(args_iter)
                elif node.op == 'get_attr':
                    result = fetch_attr(node.target)
                elif node.op == 'call_function':
                    result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
                elif node.op == 'call_method':
                    self_obj, *args = load_arg(node.args)
                    kwargs = load_arg(node.kwargs)
                    result = getattr(self_obj, node.target)(*args, **kwargs)
                elif node.op == 'call_module':
                    result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
                elif node.op == 'output':
                    # The `output` node refers to the value(s) the method returns.
                    return load_arg(node.args[0])

                # This is the only code specific to shape propagation.
                # You can delete this `if` branch and this becomes
                # a generic GraphModule interpreter.
                if isinstance(result, torch.Tensor):
                    node.shape = result.shape
                    node.dtype = result.dtype

                env[node.name] = result

As you can see, a full interpreter for FX is not that complicated,
but it can be very useful. To ease using this pattern, we provide
the :class:`Interpreter` class, which encompasses the above logic
in a way that certain aspects of the interpreter's execution can
be overridden via method overrides.

In addition to executing operations, we can also generate a new
:class:`Graph` by feeding :class:`Proxy` values through an interpreter.
Similarly, we provide the :class:`Transformer` class to encompass
this pattern. :class:`Transformer` behaves similarly to
:class:`Interpreter`, but instead of calling the ``run`` method to
get a concrete output value from the Module, you would call the
:meth:`Transformer.transform` method to return a new
:class:`GraphModule` which was subject to any transformation rules
you installed as overridden methods.

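For illustration, here is a minimal, hedged sketch of that pattern.
``ShapeLoggingInterpreter`` is a made-up name, not part of FX; it overrides
:meth:`Interpreter.run_node` to observe every intermediate value while delegating
the actual execution to the base class::

    import torch
    import torch.fx

    class ShapeLoggingInterpreter(torch.fx.Interpreter):
        def run_node(self, n: torch.fx.Node):
            # Reuse the stock execution logic, then log what came out.
            result = super().run_node(n)
            if isinstance(result, torch.Tensor):
                print(f"{n.op} {n.name}: shape={tuple(result.shape)}, dtype={result.dtype}")
            return result

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x).sum(dim=-1)

    gm = torch.fx.symbolic_trace(M())
    ShapeLoggingInterpreter(gm).run(torch.randn(3, 4))
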
Examples of the Interpreter Pattern
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- `Shape
  Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py>`__
- `Performance Profiler <https://github.com/pytorch/tutorials/pull/1319>`__


Debugging
-----------

Introduction
^^^^^^^^^^^^^^^^

Often in the course of authoring transformations, our code will not be quite right.
In this case, we may need to do some debugging. The key is to work
backwards: first, check the results of invoking the generated module to prove or
disprove correctness. Next, inspect and debug the generated code. Finally, debug the
process of transformations that led to the generated code.

If you're not familiar with debuggers, please see the auxiliary section
:ref:`Available Debuggers`.


Common Pitfalls in Transform Authoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* Nondeterministic ``set`` iteration order. In Python, the ``set`` datatype is
  unordered. Using ``set`` to contain collections of objects like ``Node``\ s,
  for example, can cause unexpected nondeterminism. An example is iterating
  over a set of ``Node``\ s to insert them into a ``Graph``. Because the
  ``set`` data type is unordered, the ordering of the operations in the output
  program will be nondeterministic and can change across program invocations.
  The recommended alternative is to use a ``dict`` data type, which is
  `insertion ordered <https://mail.python.org/pipermail/python-dev/2017-December/151283.html>`_
  as of Python 3.7 (and as of CPython 3.6). A ``dict`` can be used equivalently
  to a set by storing values to be deduplicated in the keys of the ``dict``,
  as in the sketch below.

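A small sketch of this pattern, assuming ``gm`` is some traced
:class:`GraphModule`::

    # Deduplicate while preserving first-insertion order by using dict keys.
    nodes_of_interest = [n for n in gm.graph.nodes if n.op == 'call_function']
    unique_nodes = dict.fromkeys(nodes_of_interest)
    for node in unique_nodes:    # deterministic iteration order
        print(node.name)
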
Checking Correctness of Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because the output of most deep learning modules consists of floating
point :class:`torch.Tensor` instances, checking for equivalence between
the results of two :class:`torch.nn.Module`\ s is not as straightforward
as doing a simple equality check. To motivate this, let's use an
example:

::

    import torch
    import torch.fx
    import torchvision.models as models

    def transform(m : torch.nn.Module) -> torch.nn.Module:
        gm = torch.fx.symbolic_trace(m)

        # Imagine we're doing some transforms here
        # <...>

        gm.recompile()

        return gm

    resnet18 = models.resnet18()
    transformed_resnet18 = transform(resnet18)

    input_image = torch.randn(5, 3, 224, 224)

    assert resnet18(input_image) == transformed_resnet18(input_image)
    """
    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
    """

Here, we've tried to check equality of the values of two deep learning
models with the ``==`` equality operator. However, this is not well-defined,
both because that operator returns a tensor rather than a bool, and because
comparison of floating point values should use a margin of error (or epsilon)
to account for the non-associativity of floating point operations (see
`here <https://floating-point-gui.de/errors/comparison/>`__ for more
details). We can use :func:`torch.allclose` instead, which will give
us an approximate comparison taking into account a relative and
absolute tolerance threshold:

::

    assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

This is the first tool in our toolbox to check if transformed modules are
behaving as we expect compared to a reference implementation.

Debugging the Generated Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using
traditional debugging techniques like ``print`` statements or ``pdb`` is
not as straightforward. Luckily, we have several techniques we can use
for debugging the generated code.

Use ``pdb``
~~~~~~~~~~~~~
Invoke ``pdb`` to step into the running program. Although the code that
represents the :class:`Graph` is not in any source file, we can still step
into it manually using ``pdb`` when the forward pass is invoked.

::

    import torch
    import torch.fx
    import torchvision.models as models

    def my_pass(inp: torch.nn.Module,
                tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph = tracer_class().trace(inp)
        # Transformation logic here
        # <...>

        # Return new Module
        return torch.fx.GraphModule(inp, graph)

    my_module = models.resnet18()
    my_module_transformed = my_pass(my_module)

    input_value = torch.randn(5, 3, 224, 224)

    # When this line is executed at runtime, we will be dropped into an
    # interactive `pdb` prompt. We can use the `step` or `s` command to
    # step into the execution of the next line
    import pdb; pdb.set_trace()

    my_module_transformed(input_value)

.. _Print the Generated Code:

Print the Generated Code
~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you'd like to run the same code multiple times, then it can be
a bit tedious to step to the right code with ``pdb``. In that case, one
approach is to simply copy-paste the generated ``forward`` pass into
your code and examine it from there.

::

    # Assume that `traced` is a GraphModule that has undergone some
    # number of transforms

    # Copy this code for later
    print(traced)
    # Print the code generated from symbolic tracing. This outputs:
    """
    def forward(self, y):
        x = self.x
        add_1 = x + y; x = y = None
        return add_1
    """

    # Subclass the original Module
    class SubclassM(M):
        def __init__(self):
            super().__init__()

        # Paste the generated `forward` function (the one we printed and
        # copied above) here
        def forward(self, y):
            x = self.x
            add_1 = x + y; x = y = None
            return add_1

    # Create an instance of the original, untraced Module. Then, create an
    # instance of the Module with the copied `forward` function. We can
    # now compare the output of both the original and the traced version.
    pre_trace = M()
    post_trace = SubclassM()

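For example, the comparison might look like the following sketch (the input
shape is made up, since ``M`` here stands in for your own module)::

    inp = torch.randn(3, 4)
    torch.testing.assert_close(pre_trace(inp), post_trace(inp))
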
Use the ``to_folder`` Function From ``GraphModule``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:meth:`GraphModule.to_folder` is a method in ``GraphModule`` that allows
you to dump out the generated FX code to a folder. Although copying the
forward pass into the code often suffices, as in :ref:`Print the Generated Code`,
it may be easier to examine modules and parameters using ``to_folder``.

::

    m = symbolic_trace(M())
    m.to_folder("foo", "Bar")
    from foo import Bar
    y = Bar()

After running the above example, we can then look at the code within
``foo/module.py`` and modify it as desired (e.g. adding ``print``
statements or using ``pdb``) to debug the generated code.

Debugging the Transformation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Now that we've identified that a transformation is creating incorrect
code, it's time to debug the transformation itself. First, we'll check
the :ref:`Limitations of Symbolic Tracing` section in the documentation.
Once we verify that tracing is working as expected, the goal
becomes figuring out what went wrong during our ``GraphModule``
transformation. There may be a quick answer in
:ref:`Writing Transformations`, but, if not, there are several ways to
examine our traced module:

::

    # Sample Module
    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # Create an instance of `M`
    m = M()

    # Symbolically trace an instance of `M` (returns a GraphModule). In
    # this example, we'll only be discussing how to inspect a
    # GraphModule, so we aren't showing any sample transforms for the
    # sake of brevity.
    traced = symbolic_trace(m)

    # Print the code produced by tracing the module.
    print(traced)
    # The generated `forward` function is:
    """
    def forward(self, x, y):
        add = x + y; x = y = None
        return add
    """

    # Print the internal Graph.
    print(traced.graph)
    # This print-out returns:
    """
    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
        return add
    """

    # Print a tabular representation of the internal Graph.
    traced.graph.print_tabular()
    # This gives us:
    """
    opcode         name    target                   args    kwargs
    -------------  ------  -----------------------  ------  --------
    placeholder    x       x                        ()      {}
    placeholder    y       y                        ()      {}
    call_function  add     <built-in function add>  (x, y)  {}
    output         output  output                   (add,)  {}
    """

Using the utility functions above, we can compare our traced Module
before and after we've applied our transformations. Sometimes, a
simple visual comparison is enough to trace down a bug. If it's still
not clear what's going wrong, a debugger like ``pdb`` can be a good
next step.

Going off of the example above, consider the following code:

::

    # Sample user-defined function
    def transform_graph(module: torch.nn.Module,
                        tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        # Get the Graph from our traced Module
        g = tracer_class().trace(module)

        """
        Transformations on `g` go here
        """

        return torch.fx.GraphModule(module, g)

    # Transform the Graph
    transformed = transform_graph(traced)

    # Print the new code after our transforms. Check to see if it was
    # what we expected
    print(transformed)

Using the above example, let's say that the call to ``print(transformed)``
showed us that there was an error in our transforms. We want to find out
what went wrong using a debugger. We start a ``pdb`` session. We can see
what's happening during the transform by breaking on
``transform_graph(traced)``, then pressing ``s`` to "step into" the call
to ``transform_graph(traced)``.

We may also have good luck by editing the ``print_tabular`` method to print
different attributes of the Nodes in the Graph. (For example, we might
want to see the Node's ``all_input_nodes`` and ``users``.)

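Rather than editing FX itself, a quick loop over the nodes can dump the same
information; a sketch, assuming ``traced`` is the :class:`GraphModule` under
inspection::

    for node in traced.graph.nodes:
        print(f"{node.name}: inputs={[n.name for n in node.all_input_nodes]}, "
              f"users={[n.name for n in node.users]}")
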
.. _Available Debuggers:

Available Debuggers
^^^^^^^^^^^^^^^^^^^^^^

The most common Python debugger is
`pdb <https://docs.python.org/3/library/pdb.html>`__. You can start
your program in "debug mode" with ``pdb`` by typing
``python -m pdb FILENAME.py`` into the command line, where ``FILENAME``
is the name of the file you want to debug. After that, you can use the
``pdb`` `debugger commands
<https://docs.python.org/3/library/pdb.html#debugger-commands>`__
to move through your running program stepwise. It's common to set a
breakpoint (``b LINE-NUMBER``) when you start ``pdb``, then call ``c`` to
run the program until that point. This prevents you from having to step
through each line of execution (using ``s`` or ``n``) to get to the part
of the code you want to examine. Alternatively, you can write
``import pdb; pdb.set_trace()`` before the line you want to break at.
If you add ``pdb.set_trace()``, your program will automatically start
in debug mode when you run it. (In other words, you can just type
``python FILENAME.py`` into the command line instead of
``python -m pdb FILENAME.py``.) Once you're running your file in
debug mode, you can step through the code and examine your program's
internal state using certain commands. There are many excellent
tutorials on ``pdb`` online, including RealPython's
`"Python Debugging With Pdb" <https://realpython.com/python-debugging-pdb/>`__.

IDEs like PyCharm or VSCode usually have a debugger built in. In your
IDE, you can choose to either a) use ``pdb`` by pulling up a terminal
window in your IDE (e.g. View → Terminal in VSCode), or b) use the
built-in debugger (usually a graphical wrapper around ``pdb``).

.. _Limitations of Symbolic Tracing:

Limitations of Symbolic Tracing
-------------------------------

FX uses a system of **symbolic tracing** (a.k.a `symbolic
execution <https://en.wikipedia.org/wiki/Symbolic_execution>`__)
to capture the semantics of programs in a transformable/analyzable form.
The system is **tracing** in that it executes the program (really a
:class:`torch.nn.Module` or function) to record operations. It is
**symbolic** in that the data flowing through the program during this
execution is not real data, but rather symbols (:class:`Proxy` in FX parlance).

Although symbolic tracing works for most neural net code, it has some
limitations.

Dynamic Control Flow
^^^^^^^^^^^^^^^^^^^^

The main limitation of symbolic tracing is that it does not currently support
*dynamic control flow*. That is, loops or ``if`` statements where the
condition may depend on the input values of the program.

For example, let's examine the following program:

::

    def func_to_trace(x):
        if x.sum() > 0:
            return torch.relu(x)
        else:
            return torch.neg(x)

    traced = torch.fx.symbolic_trace(func_to_trace)
    """
      <...>
      File "dyn.py", line 6, in func_to_trace
        if x.sum() > 0:
      File "pytorch/torch/fx/proxy.py", line 155, in __bool__
        return self.tracer.to_bool(self)
      File "pytorch/torch/fx/proxy.py", line 85, in to_bool
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
    torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
    """

The condition to the ``if`` statement relies on the value of ``x.sum()``,
which relies on the value of ``x``, a function input. Since
``x`` can change (i.e. if you pass a new input tensor to the traced
function), this is *dynamic control flow*. The traceback walks back up
through your code to show you where this situation happens.

Static Control Flow
~~~~~~~~~~~~~~~~~~~

On the other hand, so-called *static control flow* is supported. Static
control flow is loops or ``if`` statements whose condition cannot change
across invocations. Typically, in PyTorch programs, this control flow
arises for code making decisions about a model's architecture based on
hyper-parameters. As a concrete example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self, do_activation : bool = False):
            super().__init__()
            self.do_activation = do_activation
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            x = self.linear(x)
            # This if-statement is so-called static control flow.
            # Its condition does not depend on any input values
            if self.do_activation:
                x = torch.relu(x)
            return x

    without_activation = MyModule(do_activation=False)
    with_activation = MyModule(do_activation=True)

    traced_without_activation = torch.fx.symbolic_trace(without_activation)
    print(traced_without_activation.code)
    """
    def forward(self, x):
        linear_1 = self.linear(x); x = None
        return linear_1
    """

    traced_with_activation = torch.fx.symbolic_trace(with_activation)
    print(traced_with_activation.code)
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x); x = None
        relu_1 = torch.relu(linear_1); linear_1 = None
        return relu_1
    """

The if-statement ``if self.do_activation`` does not depend on any
function inputs, thus it is static. ``do_activation`` can be considered
a hyper-parameter, and the traces of different instances of
``MyModule`` with different values for that parameter have different
code. This is a valid pattern that is supported by symbolic tracing.

Many instances of dynamic control flow are semantically static control
flow. These instances can be made to support symbolic tracing by
removing the data dependencies on input values, for example by moving
values to ``Module`` attributes or by binding concrete values to arguments
during symbolic tracing:

::

    import torch.fx

    def f(x, flag):
        if flag: return x
        else: return x*2

    torch.fx.symbolic_trace(f) # Fails!

    torch.fx.symbolic_trace(f, concrete_args={'flag': True})

In the case of truly dynamic control flow, the sections of the program
that contain this code can be traced as calls to a method (see
:ref:`Customizing Tracing`) or to a function (see
:func:`wrap`) rather than being traced through.

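A hedged sketch of the :func:`wrap`-based approach (``clamp_if_negative`` is a
made-up function; note that :func:`wrap` must be applied to a function defined
at module scope)::

    import torch
    import torch.fx

    @torch.fx.wrap
    def clamp_if_negative(x):
        # Truly dynamic control flow: runs at runtime, is never traced through.
        return torch.relu(x) if x.sum() < 0 else x

    def f(x):
        return clamp_if_negative(x) + 1

    traced = torch.fx.symbolic_trace(f)
    print(traced.code)  # clamp_if_negative shows up as a single call_function node
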
Non-\ ``torch`` Functions
^^^^^^^^^^^^^^^^^^^^^^^^^

FX uses ``__torch_function__`` as the mechanism by which it intercepts
calls (see the `technical
overview <https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#technical-details>`__
for more information about this). Some functions, such as builtin Python
functions or those in the ``math`` module, are not covered by
``__torch_function__``, but we would still like to capture them in
symbolic tracing. For example:

::

    import torch
    import torch.fx
    from math import sqrt

    def normalize(x):
        """
        Normalize `x` by the size of the batch dimension
        """
        return x / sqrt(len(x))

    # It's valid Python code
    normalize(torch.rand(3, 4))

    traced = torch.fx.symbolic_trace(normalize)
    """
      <...>
      File "sqrt.py", line 9, in normalize
        return x / sqrt(len(x))
      File "pytorch/torch/fx/proxy.py", line 161, in __len__
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
    RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
    """

The error tells us that the built-in function ``len`` is not supported.
We can make it so that functions like this are recorded in the trace as
direct calls using the :func:`wrap` API:

::

    torch.fx.wrap('len')
    torch.fx.wrap('sqrt')

    traced = torch.fx.symbolic_trace(normalize)

    print(traced.code)
    """
    import math
    def forward(self, x):
        len_1 = len(x)
        sqrt_1 = math.sqrt(len_1); len_1 = None
        truediv = x / sqrt_1; x = sqrt_1 = None
        return truediv
    """

.. _Customizing Tracing:

Customizing Tracing with the ``Tracer`` class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :class:`Tracer` class underlies the implementation of
``symbolic_trace``. The behavior of tracing can be
customized by subclassing Tracer, like so:

::

    class MyCustomTracer(torch.fx.Tracer):
        # Inside here you can override various methods
        # to customize tracing. See the `Tracer` API
        # reference
        pass


    # Let's use this custom tracer to trace through this module
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + torch.ones(3, 4)

    mod = MyModule()

    traced_graph = MyCustomTracer().trace(mod)
    # trace() returns a Graph. Let's wrap it up in a
    # GraphModule to make it runnable
    traced = torch.fx.GraphModule(mod, traced_graph)

Leaf Modules
~~~~~~~~~~~~

Leaf Modules are the modules that appear as calls in the symbolic trace
rather than being traced through. The default set of leaf modules is the
set of standard ``torch.nn`` module instances.

For example:

::

    import torch
    import torch.fx

    class MySpecialSubmodule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 4)
            self.submod = MySpecialSubmodule()

        def forward(self, x):
            return self.submod(self.linear(x))

    traced = torch.fx.symbolic_trace(MyModule())
    print(traced.code)
    # `linear` is preserved as a call, yet `submod` is traced through.
    # This is because the default set of "Leaf Modules" includes all
    # standard `torch.nn` modules.
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x); x = None
        neg_1 = torch.neg(linear_1); linear_1 = None
        return neg_1
    """

The set of leaf modules can be customized by overriding
:meth:`Tracer.is_leaf_module`.

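For instance, here is a hedged sketch that keeps ``MySpecialSubmodule`` (from the
example above) as a leaf so that it shows up as a single ``call_module`` node
(``LeafTracer`` is a made-up name)::

    class LeafTracer(torch.fx.Tracer):
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            # Keep the default behavior, but also treat MySpecialSubmodule as a leaf.
            return (isinstance(m, MySpecialSubmodule)
                    or super().is_leaf_module(m, module_qualified_name))

    mod = MyModule()
    graph = LeafTracer().trace(mod)
    traced = torch.fx.GraphModule(mod, graph)
    print(traced.code)  # `self.submod(...)` is now preserved as a call_module node
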
Miscellanea
^^^^^^^^^^^

- Tensor constructors (e.g. ``torch.zeros``, ``torch.ones``,
  ``torch.rand``, ``torch.randn``, ``torch.sparse_coo_tensor``)
  are currently not traceable.

  - The deterministic constructors (``zeros``, ``ones``) can be used
    and the value they produce will be embedded in the trace as a
    constant. This is only problematic if the arguments to these
    constructors refer to dynamic input sizes. In this case,
    ``ones_like`` or ``zeros_like`` may be a viable substitute.
  - Nondeterministic constructors (``rand``, ``randn``) will have a
    single random value embedded in the trace. This is likely not the
    intended behavior. One workaround is to wrap ``torch.randn`` in a
    ``torch.fx.wrap`` function and call that instead.

    ::

        @torch.fx.wrap
        def torch_randn(x, shape):
            return torch.randn(shape)

        def f(x):
            return x + torch_randn(x, 5)

        torch.fx.symbolic_trace(f)

  - This behavior may be fixed in a future release.

- Type annotations

  - Python 3-style type annotations (e.g.
    ``func(x : torch.Tensor, y : int) -> torch.Tensor``) are supported
    and will be preserved by symbolic tracing.
  - Python 2-style comment type annotations
    ``# type: (torch.Tensor, int) -> torch.Tensor`` are not currently
    supported.
  - Annotations on local names within a function are not currently
    supported.


- Gotcha around ``training`` flag and submodules

  - When using functionals like ``torch.nn.functional.dropout``, it will be common
    for the training argument to be passed in as ``self.training``. During FX tracing,
    this will likely be baked in as a constant value.

    ::

        import torch
        import torch.fx

        class DropoutRepro(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.dropout(x, training=self.training)


        traced = torch.fx.symbolic_trace(DropoutRepro())
        print(traced.code)
        """
        def forward(self, x):
            dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None
            return dropout
        """

        traced.eval()

        x = torch.randn(5, 3)
        torch.testing.assert_close(traced(x), x)
        """
        AssertionError: Tensor-likes are not close!

        Mismatched elements: 15 / 15 (100.0%)
        Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
        Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
        """

  - However, when the standard ``nn.Dropout()`` submodule is used, the training flag is
    encapsulated and, because the ``nn.Module`` object model is preserved, can be changed.

    ::

        class DropoutRepro2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.drop = torch.nn.Dropout()

            def forward(self, x):
                return self.drop(x)

        traced = torch.fx.symbolic_trace(DropoutRepro2())
        print(traced.code)
        """
        def forward(self, x):
            drop = self.drop(x); x = None
            return drop
        """

        traced.eval()

        x = torch.randn(5, 3)
        torch.testing.assert_close(traced(x), x)

  - Because of this difference, consider marking modules that interact with the
    ``training`` flag dynamically as leaf modules.


API Reference
-------------

.. autofunction:: torch.fx.symbolic_trace

.. autofunction:: torch.fx.wrap

.. autoclass:: torch.fx.GraphModule
    :members:

    .. automethod:: __init__

.. autoclass:: torch.fx.Graph
    :members:

    .. automethod:: __init__

.. autoclass:: torch.fx.Node
    :members:

.. autoclass:: torch.fx.Tracer
    :members:
    :inherited-members:

.. autoclass:: torch.fx.Proxy

.. autoclass:: torch.fx.Interpreter
    :members:

.. autoclass:: torch.fx.Transformer
    :members:

.. autofunction:: torch.fx.replace_pattern


.. The experimental and passes submodules are missing docs.
.. Adding them here for coverage but this doesn't add anything to the
.. rendered doc.
.. py:module:: torch.fx.passes
.. py:module:: torch.fx.passes.infra
.. py:module:: torch.fx.passes.backends
.. py:module:: torch.fx.passes.utils
.. py:module:: torch.fx.passes.tests
.. py:module:: torch.fx.experimental
.. py:module:: torch.fx.experimental.unification
.. py:module:: torch.fx.experimental.unification.multipledispatch
.. py:module:: torch.fx.experimental.migrate_gradual_types
.. py:module:: torch.fx.passes.dialect
.. py:module:: torch.fx.passes.dialect.common
.. py:module:: torch.fx.annotate
.. py:module:: torch.fx.config
.. py:module:: torch.fx.experimental.accelerator_partitioner
.. py:module:: torch.fx.experimental.const_fold
.. py:module:: torch.fx.experimental.debug
.. py:module:: torch.fx.experimental.graph_gradual_typechecker
.. py:module:: torch.fx.experimental.merge_matmul
.. py:module:: torch.fx.experimental.meta_tracer
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_generator
.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_transformation
.. py:module:: torch.fx.experimental.migrate_gradual_types.operation
.. py:module:: torch.fx.experimental.migrate_gradual_types.transform_to_z3
.. py:module:: torch.fx.experimental.migrate_gradual_types.util
.. py:module:: torch.fx.experimental.migrate_gradual_types.z3_types
.. py:module:: torch.fx.experimental.normalize
.. py:module:: torch.fx.experimental.optimization
.. py:module:: torch.fx.experimental.partitioner_utils
.. py:module:: torch.fx.experimental.recording
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
.. py:module:: torch.fx.experimental.unification.more
.. py:module:: torch.fx.experimental.unification.multipledispatch.conflict
.. py:module:: torch.fx.experimental.unification.multipledispatch.core
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements
.. py:module:: torch.fx.experimental.validator
.. py:module:: torch.fx.graph
.. py:module:: torch.fx.graph_module
.. py:module:: torch.fx.immutable_collections
.. py:module:: torch.fx.interpreter
.. py:module:: torch.fx.node
.. py:module:: torch.fx.operator_schemas
.. py:module:: torch.fx.passes.annotate_getitem_nodes
.. py:module:: torch.fx.passes.backends.cudagraphs
.. py:module:: torch.fx.passes.dialect.common.cse_pass
.. py:module:: torch.fx.passes.fake_tensor_prop
.. py:module:: torch.fx.passes.graph_drawer
.. py:module:: torch.fx.passes.graph_manipulation
.. py:module:: torch.fx.passes.graph_transform_observer
.. py:module:: torch.fx.passes.infra.partitioner
.. py:module:: torch.fx.passes.infra.pass_base
.. py:module:: torch.fx.passes.infra.pass_manager
.. py:module:: torch.fx.passes.net_min_base
.. py:module:: torch.fx.passes.operator_support
.. py:module:: torch.fx.passes.param_fetch
.. py:module:: torch.fx.passes.pass_manager
.. py:module:: torch.fx.passes.reinplace
.. py:module:: torch.fx.passes.runtime_assert
.. py:module:: torch.fx.passes.shape_prop
.. py:module:: torch.fx.passes.split_module
.. py:module:: torch.fx.passes.split_utils
.. py:module:: torch.fx.passes.splitter_base
.. py:module:: torch.fx.passes.tests.test_pass_manager
.. py:module:: torch.fx.passes.tools_common
.. py:module:: torch.fx.passes.utils.common
.. py:module:: torch.fx.passes.utils.fuser_utils
.. py:module:: torch.fx.passes.utils.matcher_utils
.. py:module:: torch.fx.passes.utils.matcher_with_name_node_map_utils
.. py:module:: torch.fx.passes.utils.source_matcher_utils
.. py:module:: torch.fx.proxy
.. py:module:: torch.fx.subgraph_rewriter
.. py:module:: torch.fx.tensor_type
.. py:module:: torch.fx.traceback