# FX Technical Overview

FX is a toolkit for pass writers to facilitate Python-to-Python transformation of `nn.Module` instances. The toolkit supports a subset of Python language semantics, rather than the whole Python language, to make transforms easier to implement. This feature is currently in Beta, and its API may change.

## Table of Contents

<!-- toc -->

- [Introduction](#introduction)
  - [Use Cases](#use-cases)
  - [Technical Details](#technical-details)
- [Internal Structure](#internal-structure)
  - [Graph](#graph)
  - [Node](#node)
  - [GraphModule](#graphmodule)
- [Tracing](#tracing)
  - [Symbolic Tracer](#symbolic-tracer)
  - [Proxy](#proxy)
  - [TorchDynamo](#torchdynamo)
- [The FX IR Container](#the-fx-ir-container)
- [Transformation and Codegen](#transformation-and-codegen)
- [Next steps](#next-steps)

<!-- tocstop -->

# Introduction

## Use Cases ##

FX is intended for pass writers, who use it to capture and construct `nn.Module` code in a structured way. We do not expect end users to use FX directly. A useful property of framing FX this way is that passes can be seen as functions of the form `pass(in_mod: nn.Module) -> nn.Module`, which means transformations can be chained into composable pipelines.

For example, a pipeline might apply a Quantize transformation, compose it with a Split transformation, then apply a Lower to Accelerator transformation, and finally compile the transformed Modules with TorchScript for deployment. That last step emphasizes that FX transforms should not only compose with each other: their products should also compose with other systems, such as TorchScript compilation or tracing.

By using `nn.Module` as the interface between passes, FX transforms are interoperable with each other, and the resulting model can be used anywhere an `nn.Module` can be used.
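
The `pass(in_mod) -> nn.Module` framing can be illustrated with a minimal sketch, assuming a recent PyTorch release. The two passes (`relu_to_sigmoid`, `dead_code_elimination`), the `compose` helper, and the `Toy` module are hypothetical names chosen for illustration, not part of the FX API; they rely only on `torch.fx.symbolic_trace`, `Graph.eliminate_dead_code`, and `GraphModule.recompile`.

```python
import functools
from typing import Callable, List

import torch
import torch.fx as fx

# Under this framing, a pass is any callable nn.Module -> nn.Module.
Pass = Callable[[torch.nn.Module], torch.nn.Module]


def relu_to_sigmoid(mod: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical example pass: retarget torch.relu calls to torch.sigmoid."""
    gm = fx.symbolic_trace(mod)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.relu:
            node.target = torch.sigmoid
    gm.recompile()  # regenerate forward() from the edited Graph
    return gm


def dead_code_elimination(mod: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical example pass: drop Nodes whose values are never used."""
    gm = fx.symbolic_trace(mod)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm


def compose(passes: List[Pass]) -> Pass:
    """Chain passes left to right; every intermediate result is an nn.Module."""
    def pipeline(mod: torch.nn.Module) -> torch.nn.Module:
        return functools.reduce(lambda m, p: p(m), passes, mod)
    return pipeline


class Toy(torch.nn.Module):
    def forward(self, x):
        unused = x * 2  # traced into the Graph, but nothing consumes it
        return torch.relu(x)


pipeline = compose([relu_to_sigmoid, dead_code_elimination])
out_mod = pipeline(Toy())

x = torch.randn(4)
torch.testing.assert_close(out_mod(x), torch.sigmoid(x))
print(out_mod.code)
```

Because each pass both consumes and returns an `nn.Module`, the output of one pass (a `GraphModule`) can be re-traced by the next, and the final result is still an ordinary module.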

## Technical Details ##

The following sections walk through the components that take an original `torch.nn.Module` to FX IR, and finally to generated Python code and a GraphModule instance.

FX's front-end makes use of the dynamic nature of Python to intercept call-sites for various entities (PyTorch operators, Module invocations, and Tensor method invocations). The simplest way to get an FX graph is by using `torch.fx.symbolic_trace`. We can see how this works by way of an example:

```python
import torch
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(
            torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

input = torch.rand(3, 4)
torch.testing.assert_close(symbolic_traced(input), module(input))
```

Here, we set up a simple Module that exercises different language features: fetching a parameter, applying an arithmetic operator, applying a submodule (linear), and applying a Tensor method. `symbolic_trace` returns an instance of GraphModule, which is itself a subclass of `nn.Module`. We can see that the `symbolic_traced` instance runs and returns the same result as the original `module` instance.

# Internal Structure

## [Graph](https://pytorch.org/docs/main/fx.html#torch.fx.Graph) ##

The `fx.Graph` is a core data structure in FX that represents operations and their dependencies in a structured format. It holds a sequence of `fx.Node` objects, each representing an individual operation along with its inputs and outputs. The Graph makes the model structure simple to manipulate and analyze, which is essential for implementing transformations and optimizations.

## Node

An `fx.Node` is a data structure that represents an individual operation within an `fx.Graph`; it maps to callsites such as operators, methods, and modules. Each `fx.Node` keeps track of its inputs, its previous and next nodes, a stack trace (so you can map the node back to a line of code in your Python file), and optional metadata stored in a `meta` dict.
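
To make the Graph/Node relationship concrete, the sketch below re-traces the `MyModule` example from Technical Details and walks its Nodes, printing each Node's `op`, `name`, `target`, inputs (`all_input_nodes`) and consumers (`users`), then follows a `next` link and checks the `meta` dict. It assumes a recent PyTorch release; exact Node names and the printed `target` reprs may differ across versions.

```python
import torch
import torch.fx as fx


class MyModule(torch.nn.Module):
    # Same module as in the Technical Details example.
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


gm = fx.symbolic_trace(MyModule())

# Iterating gm.graph.nodes walks the Graph's Nodes in execution order.
for node in gm.graph.nodes:
    inputs = [n.name for n in node.all_input_nodes]  # Nodes this Node reads from
    users = [n.name for n in node.users]             # Nodes that read this Node
    print(f"{node.op:<14} {node.name:<8} target={node.target} "
          f"inputs={inputs} users={users}")

ops = [node.op for node in gm.graph.nodes]

# Nodes are also linked to their neighbours and carry a free-form `meta` dict.
first = next(iter(gm.graph.nodes))
print(first.name, "->", first.next.name, type(first.meta).__name__)
```

Every `op` kind except `call_module`'s neighbours is exercised here: a `placeholder` for `x`, a `get_attr` for `self.param`, a `call_function` for `+`, a `call_module` for `self.linear`, a `call_method` for `.clamp`, and the final `output`.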

## [GraphModule](https://pytorch.org/docs/main/fx.html#torch.fx.GraphModule) ##

The `fx.GraphModule` is a subclass of `nn.Module` that holds the Graph, the attributes of the original module that the Graph references (such as parameters and submodules), and the generated source code. It is the primary output of FX transformations and can be used like any other `nn.Module`. Because it generates a valid `forward` method from the Graph's structure, an `fx.GraphModule` can execute the transformed model directly.


# Tracing

## [Symbolic Tracer](https://pytorch.org/docs/main/fx.html#torch.fx.Tracer) ##

`Tracer` is the class that implements the symbolic tracing functionality of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is essentially `Tracer().trace(m)`, with the resulting Graph wrapped in a GraphModule. Tracer can be subclassed to override various behaviors of the tracing process; the behaviors that can be overridden are described in the docstrings of the methods on the class.

In the default implementation of `Tracer().trace`, the tracer first creates Proxy objects for all arguments of the `forward` function. (This happens in the call to `create_args_for_root`.) Next, the `forward` function is called with the new Proxy arguments. As the Proxies flow through the program, they record every operation they touch (`torch` function calls, method calls, and operators) into the growing FX Graph as Nodes.
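
A sketch of subclassing `Tracer`, assuming a recent PyTorch release. It overrides `is_leaf_module`, one of the overridable hooks documented on the class, so that a user-defined submodule is recorded as a single `call_module` Node instead of being traced through. `MySub`, `Root`, and `LeafTracer` are hypothetical names for illustration. Note that `Tracer.trace` returns a bare `Graph`, which is then wrapped in a `GraphModule` to make it runnable.

```python
import torch
import torch.fx as fx


class MySub(torch.nn.Module):
    # Hypothetical user-defined submodule.
    def forward(self, x):
        return torch.sin(x) * 2


class Root(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub = MySub()

    def forward(self, x):
        return self.sub(x) + 1


class LeafTracer(fx.Tracer):
    # Decide whether a submodule call is recorded as one call_module Node
    # (leaf) or traced through; keep MySub opaque, defer to the default otherwise.
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return isinstance(m, MySub) or super().is_leaf_module(m, module_qualified_name)


root = Root()

# Default Tracer: MySub is inlined into sin/mul call_function Nodes.
default_graph = fx.Tracer().trace(root)
# Custom Tracer: MySub stays opaque as a single call_module Node.
leaf_graph = LeafTracer().trace(root)

default_ops = [n.op for n in default_graph.nodes]
leaf_ops = [n.op for n in leaf_graph.nodes]
print(default_ops)
print(leaf_ops)

# Tracer.trace returns a Graph; wrap it to get a runnable module.
gm = fx.GraphModule(root, leaf_graph)
```

Keeping a submodule as a leaf is useful when a later pass should treat it as a unit (for example, to swap it out wholesale) rather than see its internals.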

## Proxy ##

Proxy objects are Node wrappers used by the Tracer to record operations seen during symbolic tracing. The mechanism through which Proxy objects record computation is [`__torch_function__`](https://pytorch.org/docs/stable/notes/extending.html#extending-torch). If a custom Python type defines a method named `__torch_function__`, PyTorch invokes that implementation when an instance of the type is passed to a function in the `torch` namespace. In FX, when operations on a Proxy are dispatched to the `__torch_function__` handler, the handler records the operation in the Graph as a Node. The recorded Node is then itself wrapped in a Proxy, so further ops can be applied to that value.

Consider the following example:

```python
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

m = M()
traced = symbolic_trace(m)
```

During the call to `symbolic_trace`, the parameter `x` is transformed into a Proxy object, and the corresponding Node (a Node with `op = "placeholder"` and `target = "x"`) is added to the Graph. Then the Module is run with Proxies as inputs, and recording happens via the `__torch_function__` dispatch path.
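
The hypothetical variant of the `M` example below records what `forward` actually receives during `symbolic_trace`, which makes the Proxy mechanism visible: the argument is an `fx.Proxy` wrapping the `placeholder` Node, and `torch.relu` is captured as a `call_function` Node. The `seen` list is a debugging device for illustration only; a recent PyTorch release is assumed.

```python
import torch
import torch.fx as fx

seen = []  # captures whatever forward() receives while being traced


class M(torch.nn.Module):
    def forward(self, x):
        seen.append(x)  # during symbolic tracing, x is a Proxy, not a Tensor
        return torch.relu(x)


traced = fx.symbolic_trace(M())

proxy = seen[0]
# The Proxy wraps the placeholder Node created for the argument `x`.
print(type(proxy).__name__, proxy.node.op, proxy.node.target)

# torch.relu(proxy) was routed through __torch_function__ and recorded
# as a call_function Node whose target is torch.relu itself.
for node in traced.graph.nodes:
    print(node.op, node.target)
```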

If you're doing graph transforms, you can wrap your own Proxy around a raw Node so that you can use the overloaded operators to add more Nodes to a Graph.

## [TorchDynamo](https://pytorch.org/docs/main/torch.compiler_dynamo_deepdive.html) ##

Symbolic tracing has limitations: it cannot handle dynamic (data-dependent) control flow, and it produces only a single graph at a time. A more capable alternative is the `torch.compile()` infrastructure, which can output multiple subgraphs as `torch.fx` graphs in either ATen or Torch IR. [This tutorial](https://colab.research.google.com/drive/1Zh-Uo3TcTH8yYJF-LLo5rjlHVMtqvMdf) gives more context on how this works.


# The FX IR Container

Tracing captures an intermediate representation (IR), which is represented as a doubly-linked list of Nodes.

Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows:

- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout.
- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care.
- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention.
- `call_module` applies the `forward()` method of a module in the module hierarchy to the given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *excluding the self argument*.
- `call_method` calls a method on a value. `name` is as previous. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the method on, *including the self argument*.
- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout.

To facilitate analysis of data dependencies, Nodes have read-only properties `all_input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are stored as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such.

# Transformation and Codegen

An invocation of `symbolic_traced` above requires a valid `forward()` method to be defined on the Module instance. How does this work? GraphModule actually generates valid Python source code based on the IR it is instantiated with. This can be seen by accessing the `code` attribute on the GraphModule: `print(symbolic_traced.code)`.

After tracing, the code given under [Technical Details](#technical-details) is represented roughly as follows (generated variable names and formatting vary between PyTorch releases):

```python
def forward(self, x):
    param = self.param
    add_1 = x + param; x = param = None
    linear_1 = self.linear(add_1); add_1 = None
    clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
    return clamp_1
```

This is the core of why FX is a Python-to-Python translation toolkit. Outside users can treat the results of FX transformations as they would any other `nn.Module` instance.
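
Because the generated source is derived from the Graph, editing the Graph and calling `GraphModule.recompile()` regenerates both `code` and `forward()`. The sketch below shows that round trip, assuming a recent PyTorch release; the "add 1 after every `torch.relu`" rewrite is a hypothetical transform chosen only for illustration.

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


gm = fx.symbolic_trace(M())

# Hypothetical transform: add 1 to the result of every torch.relu call.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.add, args=(node, 1))
        # Point every user of the relu at the new Node. This also rewrites
        # new_node's own input to itself, so restore it afterwards.
        node.replace_all_uses_with(new_node)
        new_node.args = (node, 1)

gm.graph.lint()   # sanity-check the edited Graph
gm.recompile()    # regenerate Python source and forward() from the Graph
print(gm.code)
```

The printed code should now contain a `torch.add` call consuming the result of `torch.relu`, although the exact variable names depend on the PyTorch version.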

# Next steps

To learn more about obtaining FX graphs, the kinds of IR available to you, and how to write simple transformations, check out [this tutorial](https://colab.research.google.com/drive/1Zh-Uo3TcTH8yYJF-LLo5rjlHVMtqvMdf).