Writing Graph Transformations on ATen IR
========================================

Passes
------

Since the ATen IR sits at the FX Graph/GraphModule level, any
transformations written for FX Graphs can easily be applied to the
ATen IR. If you are familiar with writing FX graph transformations,
this will feel the same.

The most direct way of writing a transformation is to loop through the
given graph and directly manipulate its nodes.

For example, say we want to replace ``torch.ops.aten.add.Tensor()``
calls with ``torch.ops.aten.mul.Tensor()`` calls:

.. code:: python

    import torch

    def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.target = torch.ops.aten.mul.Tensor
        # Regenerate the GraphModule's code so it reflects the modified graph
        gm.recompile()
        return gm

We can also delete and append new nodes through the FX utility functions
documented in
`Graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__.
For example, if we want to insert a ``torch.ops.aten.relu.default()``
call after the ``add`` call:

.. code:: python

    import torch

    def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

                # Specifies the insertion point. Any nodes added to the graph within
                # this scope will be inserted after `node`
                with gm.graph.inserting_after(node):
                    # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                    new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                    # Replace all the places that used `node` to now use `new_relu_node`
                    node.replace_all_uses_with(new_relu_node)
                    # The call above also rewired `new_relu_node`'s own input, so
                    # restore it to point back at the original `add` node
                    new_relu_node.replace_input_with(new_relu_node, node)
        gm.recompile()
        return gm
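To see these passes end to end, a minimal sketch might look as follows.
The toy module ``M``, the example inputs, and the use of ``torch.export``
to obtain an ATen-level ``GraphModule`` are illustrative assumptions, not
part of the passes themselves:

.. code:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    example_inputs = (torch.randn(3), torch.randn(3))

    # Exporting lowers the module to ATen IR; grab the underlying GraphModule
    gm = torch.export.export(M(), example_inputs).graph_module

    gm = insert_relu_after_add(gm)  # add(x, y) becomes relu(add(x, y))
    print(gm.code)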
In general, transformations can be roughly categorized along a couple of
axes:

Axis A:

1. Creating one-to-X mappings (e.g. decomposition)
2. Creating many-to-one mappings (e.g. fusion)

Axis B:

1. Doing forwards iteration (e.g. shape propagation)
2. Doing backwards iteration (e.g. dead code elimination)

Axis C:

1. Dependent on local node information (e.g. out-variant conversion)
2. Dependent on global graph information (e.g. memory planning)

Our projection of how frequently these use cases occur is:

1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

Although all graph transformations can be written by directly
manipulating the graph, we also provide some helper utilities that make
the level 1 and level 2 use cases easier to write.

Transformer
~~~~~~~~~~~

For level 1 use cases (creating one-to-X mappings, doing forwards
iterations, and looking at local node information), we can utilize the
`Transformer <https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer>`__
class to execute each node and recreate the graph, except with the
transformations specified.

One-to-One Pass
^^^^^^^^^^^^^^^

As an example of a one-to-one mapping, if we want to replace an op A with
another op B, we can run the GraphModule and, every time we see op A,
return op B instead.

An example is:

.. code:: python

    class ReplaceAddWithMul(torch.fx.Transformer):
        def call_function(self, target, args, kwargs):
            if target != torch.ops.aten.add.Tensor:
                return super().call_function(target, args, kwargs)
            return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

    transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

The ``super().call_function(target, args, kwargs)`` call creates a
``call_function`` FX node and returns the result of running the
operator with the given arguments.

One-to-X Pass
^^^^^^^^^^^^^

If we want to do a one-to-X mapping, such as replacing op A with two other
ops B and C, we make two calls to ``super().call_function`` to
create two FX nodes, one with op B and another with op C, and return the
result of running op C.

For example:

.. code:: python

    class ReplaceAddWithMulSub(torch.fx.Transformer):
        """
        Original:
            def f(x, y):
                return x + y

        After pass:
            def f(x, y):
                z = x * y
                return z - y
        """
        def call_function(self, target, args, kwargs):
            if target != torch.ops.aten.add.Tensor:
                return super().call_function(target, args, kwargs)

            x, y = args

            mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
            return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

    transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

One-to-None Pass
^^^^^^^^^^^^^^^^

If we want to remove an op, we can just return the value passed into
the function:

.. code:: python

    class RemoveDetachPass(torch.fx.Transformer):
        def call_function(self, target, args, kwargs):
            if target not in (
                torch.ops.aten.detach.default,
                torch.ops.aten.detach_copy.default,
            ):
                return super().call_function(target, args, kwargs)

            assert len(args) == 1
            return args[0]

    transformed_graph_module = RemoveDetachPass(graph_module).transform()
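Since ``Transformer.transform()`` returns a new ``GraphModule`` rather than
mutating its input, these passes can be chained by feeding the result of
one into the next. A minimal sketch, assuming ``graph_module`` is an
ATen-level ``torch.fx.GraphModule`` (e.g. obtained from ``torch.export``):

.. code:: python

    # Run the one-to-X pass first, then strip any detach ops from the result
    gm = ReplaceAddWithMulSub(graph_module).transform()
    gm = RemoveDetachPass(gm).transform()
    print(gm.code)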
Utilizing Local Information
^^^^^^^^^^^^^^^^^^^^^^^^^^^

As an example of utilizing local node information: if we want to convert
all the scalars within the graph to tensors, we can run the given
``fx.GraphModule``, and for every argument that contains a scalar,
convert it to a tensor. It might look something like:

.. code:: python

    def args_map(target, fn, args, kwargs):
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)
        args = list(args)
        kwargs = kwargs.copy()

        # Update the argument based on the function passed
        def update(key, args, schema):
            args[key] = fn(args[key], schema)

        # Update each argument in the schema
        for i, schema in enumerate(target._schema.arguments):
            if schema.name in kwargs:
                update(schema.name, kwargs, schema)
            elif not schema.kwarg_only and i < len(args):
                update(i, args, schema)
        return tuple(args), kwargs

    class ScalarToTensorPass(torch.fx.Transformer):
        def call_function(self, target, args, kwargs):
            def try_coerce(value, arg):
                return (
                    torch.tensor(value)
                    if isinstance(value, (float, int, bool))
                    and type(arg.type) == torch.TensorType
                    else value
                )

            args, kwargs = args_map(target, try_coerce, args, kwargs)
            return super().call_function(target, args, kwargs)

    transformed_graph_module = ScalarToTensorPass(graph_module).transform()

Subgraph Rewriter
~~~~~~~~~~~~~~~~~

For creating many-to-one mappings, we can utilize FX's `subgraph
rewriter <https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py>`__.
Given a ``pattern``, it finds the subgraphs of operators in the graph that
match the pattern and replaces each matched subgraph with the
``replacement``.

Note:

::

    This is an in-place operation.

The ``pattern`` and ``replacement`` inputs must be callable functions or
GraphModules containing the same operators that are used within the
graph (ATen ops) so that the subgraph rewriter can find the correct
pattern in the graph. Inputs to the pattern/replacement callables will
be treated as wildcards when matching.

An example:

.. code:: python

    from torch.fx import subgraph_rewriter

    def replace_patterns(graph_module):
        def pattern(x, y):
            x = torch.ops.aten.add.Tensor(x, y)
            x = torch.ops.aten.mul.Tensor(x, y)
            return x

        def replacement(x, y):
            return torch.ops.aten.sub.Tensor(x, y)

        replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
            graph_module, pattern, replacement
        )
        return replaced_patterns

The subgraph rewriter returns a list of ``ReplacedPatterns``:

.. code:: python

    @dataclass
    class ReplacedPatterns:
        # Node from which the match was found
        anchor: Node
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        nodes_map: Dict[Node, Node]
        # List of nodes that were added into the graph
        replacements: List[Node]

Note:

::

    The nodes created by the subgraph rewriter will not have the metadata that
    is populated in the matched nodes, but you can use
    `ReplacedPatterns.nodes_map` to find the nodes in the original graph that
    were matched, and `ReplacedPatterns.replacements` to find the nodes that
    were replaced in the transformed graph.
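Building on the note above, a rough sketch of propagating metadata onto the
newly created nodes might look like the following. It reuses the ``pattern``
and ``replacement`` functions from the example above, assumes
``graph_module`` is the graph being rewritten, and the choice to copy the
``stack_trace`` entry is only an illustration:

.. code:: python

    from torch.fx import subgraph_rewriter

    replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )

    for rp in replaced_patterns:
        # `rp.anchor` is a node of the original graph that was matched; its
        # `meta` dict still carries the metadata populated before rewriting
        anchor_meta = rp.anchor.meta
        for new_node in rp.replacements:
            # The nodes created by the rewriter start with empty metadata,
            # so copy over whatever we want to preserve
            if "stack_trace" in anchor_meta:
                new_node.meta["stack_trace"] = anchor_meta["stack_trace"]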
Pass Manager
------------

The
```PassManager`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__
is a class used to run multiple passes on a given graph module. When
initializing a ``PassManager`` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of
passes on a graph module, we can pass the graph module directly to the
``PassManager`` instance.

An example:

.. code:: python

    from torch.fx.passes.infra.pass_manager import PassManager

    pm = PassManager(
        passes=[replace_add_with_div, replace_div_with_mul],
        run_checks_after_each_pass=True,
        suppress_check_failures=False,
    )
    graph_module_out = pm(graph_module)

To add a common set of checks that are run after each pass, we can call
the function ``add_checks(check: Callable)``, which takes a callable
function as input. If the ``run_checks_after_each_pass`` flag is set,
the ``check`` will be called after each pass is run on the graph module.

An example:

.. code:: python

    pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

    def check_div_target(graph_module):
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target != torch.div:
                raise ValueError("Target should be div!")

    pm.add_checks(check_div_target)

    pm(graph_module)  # raises ValueError after the replace_div_with_mul pass

Partitioner
-----------

There are a couple of common FX-graph-based partitioners we can use to
partition the graph.

Subgraph Matcher
~~~~~~~~~~~~~~~~

For finding subgraphs within a graph that match a specific pattern, we
can utilize FX's
```SubgraphMatcher`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__.

Class Attributes:

- ``pattern (Graph)``: The targeted matching pattern. Placeholder nodes
  in the graph will be treated as wildcards when matching.
- ``match_output (bool)``: If True, the output node in the pattern graph
  will be treated as a part of the targeted pattern. If False, the output
  node is ignored during the match.
- ``match_placeholder (bool)``: If True, placeholder nodes in the
  pattern graph will be treated as a part of the targeted pattern. If
  False, placeholder nodes will be used as wildcards.
- ``remove_overlapping_matches (bool)``: If True, in the case of
  overlapping matches, only the first match will be returned.
- ``ignore_literals (bool)``: If True, will not check whether literals are
  equal and will instead treat them as wildcards.

An example:

.. code:: python

    import torch
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

    class LargeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._weight = torch.nn.Parameter(torch.ones(3, 3))
            self._bias = torch.nn.Parameter(torch.ones(3, 3))

        def forward(self, x):
            return torch.ops.aten.addmm.default(self._bias, x, self._weight)

    inputs = (torch.ones(3, 3),)
    large_model_graph = torch.export.export(LargeModel(), inputs).graph

    class PatternModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
            self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

        def forward(self, x):
            return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

    pattern_inputs = (torch.ones(5, 5),)
    pattern_graph = torch.export.export(PatternModel(), pattern_inputs).graph

    subgraph_matcher = SubgraphMatcher(pattern_graph)
    match_result = subgraph_matcher.match(large_model_graph)

The ``match`` function returns a list of ``InternalMatch``:
.. code:: python

    @dataclass
    class InternalMatch():
        # Nodes from which the match was found
        anchors: List[Node]
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        nodes_map: Dict[Node, Node] = field(default_factory=dict)
        # Nodes in the target graph that are matched to placeholders in the pattern
        placeholder_nodes: List[Node] = field(default_factory=list)
        # Nodes in the matched subgraph that are returned by the output
        returning_nodes: List[Node] = field(default_factory=list)

Capability Based Partitioner
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To find the largest subgraphs of nodes that support a specific
invariant, we can utilize FX's
```CapabilityBasedPartitioner`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__.

Class Attributes:

- ``graph_module (torch.fx.GraphModule)``: The graph module we are
  partitioning.
- ``operator_support (OperatorSupportBase)``: The object used to
  determine if a node in the graph is supported in the partition.
- ``allows_single_node_partition (bool)``: If True, allows single-node
  partitions to be formed.
- ``non_compute_ops (Optional[Sequence[str]])``: A set of ops that are
  considered to be "non-compute" (for example, ``torch.ops.aten.view`` and
  ``_operator.getitem``), so that the partitioner will not create graphs
  that contain only these non-compute ops.
- ``allowed_single_node_partition_ops (Optional[Sequence[str]])``: A
  set of ops that are allowed to be in a single-node partition.

The
```OperatorSupportBase`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__
class is used by the partitioner to determine if a specific node in the
graph belongs in the partition. This is done by overriding the
``is_node_supported`` function. You can combine multiple
``OperatorSupportBase`` instances by using
```chain`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__\ (which
returns False if any of the OperatorSupportBase returns False) and
```any_chain`` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__
(which returns True if any of the OperatorSupportBase returns True).

An example:

.. code:: python

    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddMulOperatorSupport(OperatorSupportBase):
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target in [
                torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
            ]

    op_support = AddMulOperatorSupport()

    capability_partitioner = CapabilityBasedPartitioner(
        graph_module,
        op_support,
    )

    # Returns a list of partitions (list of nodes that belong in each partition)
    partition_list = capability_partitioner.propose_partitions()
    # Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
    fused_graph_module = capability_partitioner.fuse_partitions(partition_list)
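The ``chain``/``any_chain`` helpers mentioned above can be used to combine
several support checkers into one. A rough sketch follows; the extra
``ReluOperatorSupport`` checker and the decision to OR the two checkers
together with ``any_chain`` are purely illustrative, and it reuses the
``AddMulOperatorSupport`` class and ``graph_module`` from the example above:

.. code:: python

    import torch
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class ReluOperatorSupport(OperatorSupportBase):
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target == torch.ops.aten.relu.default

    # A node is supported if either checker accepts it
    combined_support = any_chain(AddMulOperatorSupport(), ReluOperatorSupport())

    capability_partitioner = CapabilityBasedPartitioner(graph_module, combined_support)
    partition_list = capability_partitioner.propose_partitions()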