Writing Graph Transformations on ATen IR
========================================

Passes
------

Since the ATen IR sits at the FX Graph/GraphModule level, any
transformations written for FX Graphs can be easily applied onto the
ATen IR. If you’re familiar with writing FX graph transformations, then
this will be the same.

The most direct way of writing transformations is by looping through the
given graph and directly manipulating the nodes within the graph.

For example, let’s say we want to replace
``torch.ops.aten.add.Tensor()`` calls with
``torch.ops.aten.mul.Tensor()`` calls:

.. code:: python

   import torch

   def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       for node in gm.graph.nodes:
           if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
               node.target = torch.ops.aten.mul.Tensor
       # Recompile so the generated Python code reflects the modified graph
       gm.recompile()
       return gm

We can also delete and append nodes through the FX utility functions
documented in
`Graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__.
For example, to insert a ``torch.ops.aten.relu.default()`` call after
the ``add`` call (a sketch of deleting a node follows this example):

.. code:: python

   import torch

   def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       for node in gm.graph.nodes:
           if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

               # Specifies the insertion point. Any nodes added to the graph within
               # this scope will be inserted after `node`
               with gm.graph.inserting_after(node):
                   # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                   new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                   # Replace all the places that use `node` to now use `new_relu_node`,
                   # excluding `new_relu_node` itself, which must keep `node` as its input
                   node.replace_all_uses_with(
                       new_relu_node, delete_user_cb=lambda user: user != new_relu_node
                   )

       gm.recompile()
       return gm

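Deleting a node works similarly: reroute its users first, then erase it
from the graph. The following is a minimal sketch (the ``remove_detach``
helper is hypothetical, written only for this illustration) that removes
``torch.ops.aten.detach.default`` calls by pointing their users at the
detach input and erasing the now-unused node:

.. code:: python

   import torch

   def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       # Iterate over a copy of the node list since we mutate the graph
       for node in list(gm.graph.nodes):
           if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
               # Reroute every user of the detach output to the detach input ...
               node.replace_all_uses_with(node.args[0])
               # ... and erase the now-unused node from the graph
               gm.graph.erase_node(node)
       gm.recompile()
       return gm

``Graph.erase_node`` raises an error if the node still has users, which
is why all uses are rerouted before erasing.
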
In general, transformations can be roughly categorized along a few
axes:

Axis A:
   1. Creating one-to-X mappings (e.g. decomposition)
   2. Creating many-to-one mappings (e.g. fusion)

Axis B:
   1. Doing forwards iteration (e.g. shape propagation)
   2. Doing backwards iteration (e.g. dead code elimination)

Axis C:
   1. Dependent on local node information (e.g. out-variant conversion)
   2. Dependent on global graph information (e.g. memory planning)

Our projection on the frequency of these use cases is:

1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

Although we can implement all graph transformations by directly
manipulating the graph, we also provide some helper utilities to make
the level 1 and level 2 use cases easier to write.

Transformer
~~~~~~~~~~~

For level 1 use cases (creating one-to-X mappings, doing forwards
iterations, and looking at local node information), we can utilize the
`Transformer <https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer>`__
class to execute each node and recreate a graph, except with the
transformations specified.

One-to-One Pass
^^^^^^^^^^^^^^^

For a one-to-one mapping, if we wanted to replace an op A with
another op B, we can run the GraphModule, and every time we see op A,
return op B.

An example is:

.. code:: python

   class ReplaceAddWithMul(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           if target != torch.ops.aten.add.Tensor:
               return super().call_function(target, args, kwargs)
           return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

   transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

The ``super().call_function(target, args, kwargs)`` call creates a
``call_function`` FX node, and returns the result of running the
operator with the given arguments.

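The value returned by ``super().call_function`` is a ``torch.fx.Proxy``
wrapping the newly created node, so a pass can also attach metadata to
the new node before returning it. A minimal sketch of this (the
``"replaced_from"`` metadata key is purely illustrative):

.. code:: python

   class ReplaceAddWithMulAnnotated(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           if target != torch.ops.aten.add.Tensor:
               return super().call_function(target, args, kwargs)
           proxy = super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
           # `proxy.node` is the FX node that was just created for the `mul` call
           proxy.node.meta["replaced_from"] = "aten.add.Tensor"  # illustrative metadata key
           return proxy
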
One-to-X Pass
^^^^^^^^^^^^^

If we wanted to do one-to-X mappings, like replacing op A with 2 other
ops B and C, we would then make 2 calls to ``super().call_function`` to
create 2 FX nodes, one with op B and another with op C, and return the
result of running op C.

For example:

.. code:: python

   class ReplaceAddWithMulSub(torch.fx.Transformer):
       """
       Original:
           def f(x, y):
               return x + y

       After pass:
           def f(x, y):
               z = x * y
               return z - y
       """
       def call_function(self, target, args, kwargs):
           if target != torch.ops.aten.add.Tensor:
               return super().call_function(target, args, kwargs)

           x, y = args

           mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
           return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

   transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

One-to-None Pass
^^^^^^^^^^^^^^^^

If we wanted to remove an op, we can just return the value passed into
the function:

.. code:: python

   class RemoveDetachPass(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           if target not in (
               torch.ops.aten.detach.default,
               torch.ops.aten.detach_copy.default,
           ):
               return super().call_function(target, args, kwargs)

           assert len(args) == 1
           return args[0]

   transformed_graph_module = RemoveDetachPass(graph_module).transform()

Utilizing Local Information
^^^^^^^^^^^^^^^^^^^^^^^^^^^

As an example of utilizing local node information: if we wanted to
convert all the scalars within the graph to tensors, we can run the
given ``fx.GraphModule``, and for every argument that contains a scalar,
convert it to a tensor. It might look something like:

.. code:: python

   def args_map(target, fn, args, kwargs):
       assert isinstance(args, tuple)
       assert isinstance(kwargs, dict)
       args = list(args)
       kwargs = kwargs.copy()

       # Update the argument based on the function passed
       def update(key, args, schema):
           args[key] = fn(args[key], schema)

       # Update each argument in the schema
       for i, schema in enumerate(target._schema.arguments):
           if schema.name in kwargs:
               update(schema.name, kwargs, schema)
           elif not schema.kwarg_only and i < len(args):
               update(i, args, schema)
       return tuple(args), kwargs

   class ScalarToTensorPass(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           def try_coerce(value, arg):
               return (
                   torch.tensor(value)
                   if isinstance(value, (float, int, bool))
                   and type(arg.type) == torch.TensorType
                   else value
               )

           args, kwargs = args_map(target, try_coerce, args, kwargs)
           return super().call_function(target, args, kwargs)

   transformed_graph_module = ScalarToTensorPass(graph_module).transform()

Subgraph Rewriter
~~~~~~~~~~~~~~~~~

For creating many-to-one mappings, we can utilize FX’s `subgraph
rewriter <https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py>`__.
Given a ``pattern``, it creates a subgraph of operators matching the
pattern, and then replaces each matched subgraph with the
``replacement``.

.. note::

   This is an in-place operation.

The ``pattern`` and ``replacement`` inputs must be callable functions or
GraphModules containing the same operators that are used within the
graph (ATen ops) so that the subgraph rewriter can find the correct
pattern in the graph. Inputs to the pattern/replacement callables will
be treated as wildcards when matching.

An example:

.. code:: python

   from torch.fx import subgraph_rewriter

   def replace_patterns(graph_module):
       def pattern(x, y):
           x = torch.ops.aten.add.Tensor(x, y)
           x = torch.ops.aten.mul.Tensor(x, y)
           return x

       def replacement(x, y):
           return torch.ops.aten.sub.Tensor(x, y)

       replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
           graph_module, pattern, replacement
       )
       return replaced_patterns

The subgraph rewriter returns a list of ``ReplacedPatterns``:

.. code:: python

   @dataclass
   class ReplacedPatterns:
       # Node from which the match was found
       anchor: Node
       # Maps nodes in the pattern subgraph to nodes in the larger graph
       nodes_map: Dict[Node, Node]
       # List of nodes that were added into the graph
       replacements: List[Node]

.. note::

   The nodes created by the subgraph rewriter will not have the metadata that
   is populated in the matched nodes, but you can use
   ``ReplacedPatterns.nodes_map`` to find the nodes in the original graph that
   were matched, and ``ReplacedPatterns.replacements`` to find the nodes that
   were replaced in the transformed graph.

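For instance, to carry some metadata over from the matched nodes to the
newly inserted ones, one can walk the returned ``ReplacedPatterns``
entries after rewriting. A minimal sketch, assuming ``graph_module`` and
``pattern``/``replacement`` callables like those in the example above
(which ``meta`` entries to copy is up to the pass author):

.. code:: python

   from torch.fx import subgraph_rewriter

   replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
       graph_module, pattern, replacement
   )

   for replaced in replaced_patterns:
       # `nodes_map` maps pattern-graph nodes to the nodes they matched in the
       # original graph; look up the node the match was anchored on
       anchor_in_graph = replaced.nodes_map[replaced.anchor]
       for new_node in replaced.replacements:
           # Copy, for example, the stack trace of the matched anchor node
           new_node.meta["stack_trace"] = anchor_in_graph.meta.get("stack_trace")
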
Pass Manager
------------

The
`PassManager <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__
is a class used to run multiple passes on a given graph module. When
initializing a ``PassManager`` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of
passes on a graph module, we can pass the graph module directly to the
``PassManager`` instance.

An example:

.. code:: python

   from torch.fx.passes.infra.pass_manager import PassManager

   pm = PassManager(
       passes=[replace_add_with_div, replace_div_with_mul],
       run_checks_after_each_pass=True,
       suppress_check_failures=False,
   )
   graph_module_out = pm(graph_module)

To add a common set of checks that are run after each pass, we can call
the function ``add_checks(check: Callable)``, which takes in a callable
function as input. If the ``run_checks_after_each_pass`` flag is set,
the ``check`` will be called after each pass is run on the graph module.

An example:

.. code:: python

   pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

   def check_div_target(graph_module):
       for node in graph_module.graph.nodes:
           if node.op == "call_function" and node.target != torch.div:
               raise ValueError("Target should be div!")

   pm.add_checks(check_div_target)

   pm(graph_module)    # raises ValueError after replace_div_with_mul pass

Partitioner
-----------

There are a couple of common FX-graph-based partitioners we can use to
partition the graph.

Subgraph Matcher
~~~~~~~~~~~~~~~~

For finding subgraphs within a graph that match a specific pattern, we
can utilize FX’s
`SubgraphMatcher <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__.

Class Attributes:

-  ``pattern (Graph)``: The targeted matching pattern. Placeholder nodes
   in the graph will be treated as wildcards when matching.
-  ``match_output (bool)``: If True, the output node in the pattern graph
   will be treated as a part of the targeted pattern. If False, the output
   node is ignored during the match.
-  ``match_placeholder (bool)``: If True, the placeholder nodes in the
   pattern graph will be treated as a part of the targeted pattern. If
   False, placeholder nodes will be used as wildcards.
-  ``remove_overlapping_matches (bool)``: If True, in the case of
   overlapping matches, only the first match will be returned.
-  ``ignore_literals (bool)``: If True, will not check if literals are
   equal and will instead treat them as wildcards.

An example:

.. code:: python

   from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

   class LargeModel(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self._weight = torch.nn.Parameter(torch.ones(3, 3))
           self._bias = torch.nn.Parameter(torch.ones(3, 3))

       def forward(self, x):
           return torch.ops.aten.addmm.default(self._bias, x, self._weight)

   inputs = (torch.ones(3, 3),)
   large_model_graph = torch.export.export(LargeModel(), inputs).graph

   class PatternModel(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
           self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

       def forward(self, x):
           return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

   pattern_inputs = (torch.ones(5, 5),)
   pattern_graph = torch.export.export(PatternModel(), pattern_inputs).graph

   subgraph_matcher = SubgraphMatcher(pattern_graph)
   match_result = subgraph_matcher.match(large_model_graph)

The ``match`` function returns a list of ``InternalMatch``:

.. code:: python

   @dataclass
   class InternalMatch():
       # Nodes from which the match was found
       anchors: List[Node]
       # Maps nodes in the pattern subgraph to nodes in the larger graph
       nodes_map: Dict[Node, Node] = field(default_factory=dict)
       # Nodes in target graph that are matched placeholder in pattern
       placeholder_nodes: List[Node] = field(default_factory=list)
       # Nodes in matched subgraph returned by output
       returning_nodes: List[Node] = field(default_factory=list)

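Continuing the example above, each ``InternalMatch`` can be inspected
through ``nodes_map`` and ``returning_nodes`` to locate the matched
nodes in the target graph; a brief sketch:

.. code:: python

   for internal_match in match_result:
       # Nodes in the target graph that produce the matched subgraph's outputs
       print("matched outputs:", internal_match.returning_nodes)
       # Walk the pattern-node to target-node correspondence
       for pattern_node, target_node in internal_match.nodes_map.items():
           print(f"{pattern_node.name} -> {target_node.name}")
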
Capability Based Partitioner
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To find the largest subgraphs of nodes that support a specific
invariant, we can utilize FX’s
`CapabilityBasedPartitioner <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__.

Class Attributes:

-  ``graph_module (torch.fx.GraphModule)``: The graph module we are
   partitioning on.
-  ``operator_support (OperatorSupportBase)``: The object used to
   determine if a node in the graph is supported in the partition.
-  ``allows_single_node_partition (bool)``: If True, allows single node
   partitions to be formed.
-  ``non_compute_ops (Optional[Sequence[str]])``: A set of ops that are
   considered to be “non-compute” (e.g. ``torch.ops.aten.view`` and
   ``_operator.getitem``), so that the partitioner will not create graphs
   that only contain these non-compute ops.
-  ``allowed_single_node_partition_ops (Optional[Sequence[str]])``: A
   set of ops that are allowed to be in a single node partition.

The
`OperatorSupportBase <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__
class is used by the partitioner to determine if a specific node in the
graph belongs in the partition. This is done by overriding the
``is_node_supported`` function. You can chain multiple
``OperatorSupportBase`` instances by using
`chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__ (which
returns False if any of the ``OperatorSupportBase`` returns False) and
`any_chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__
(which returns True if any of the ``OperatorSupportBase`` returns True);
a sketch of composing supports with ``any_chain`` follows the example
below.

An example:

.. code:: python

   from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
   from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

   class AddMulOperatorSupport(OperatorSupportBase):
       def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           return node.op == "call_function" and node.target in [
               torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
           ]

   op_support = AddMulOperatorSupport()

   capability_partitioner = CapabilityBasedPartitioner(
       graph_module,
       op_support,
   )

   # Returns a list of partitions (list of nodes that belong in each partition)
   partition_list = capability_partitioner.propose_partitions()
   # Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
   fused_graph_module = capability_partitioner.fuse_partitions(partition_list)

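To illustrate the ``chain``/``any_chain`` helpers described above, here
is a minimal sketch that composes two supports so that a node is
accepted if either one supports it. The ``AddOperatorSupport`` and
``MulOperatorSupport`` classes are hypothetical, defined only for this
example:

.. code:: python

   from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
   from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

   class AddOperatorSupport(OperatorSupportBase):
       def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           return node.op == "call_function" and node.target == torch.ops.aten.add.Tensor

   class MulOperatorSupport(OperatorSupportBase):
       def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           return node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor

   # A node is supported if *any* of the chained supports accepts it
   op_support = any_chain(AddOperatorSupport(), MulOperatorSupport())

   capability_partitioner = CapabilityBasedPartitioner(graph_module, op_support)
   partition_list = capability_partitioner.propose_partitions()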