xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/merge_matmul.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from torch.fx.node import Node
5from torch.fx._symbolic_trace import symbolic_trace
6from torch.fx.passes.tools_common import legalize_graph
7import itertools
8import operator
9
10from typing import Dict, List, Tuple
11
12
def split_result_tensors(
    result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    """
    Split the output of a merged matmul back into per-input chunks.

    A free function used by the merge_matmul graph transformation below:
    after several matmuls sharing a RHS have been fused into one, this
    splits the fused result along dim 0 so each original consumer gets
    its own piece.

    Arguments:
        result: The merged matmul result tensor.
        inputs: The list of inputs that were merged into one for the matmul.

    Returns:
        Tuple of matmul results, one per input tensor.
    """
    # While the fx tracer is running, the elements of `inputs` are Proxies,
    # so x.shape[0] is a torch.fx Attribute rather than a concrete int.
    # torch.split needs real ints even during tracing, so use zeros then.
    if isinstance(result, torch.fx.Proxy):
        split_sizes = [0 for _ in inputs]
    else:
        split_sizes = [tensor.shape[0] for tensor in inputs]

    return torch.split(result, split_sizes)
36
37
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
    """
    Conservatively determine whether node ``a`` may depend on node ``b``
    in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: How many levels of a's inputs to walk in search of
            an indirect data dependency. If the budget is exhausted
            without a conclusion, the function conservatively reports
            that a dependency exists.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # A node is defined to depend on itself.
    if a == b:
        return True

    producers = a.all_input_nodes

    # With no inputs, a cannot depend on anything else.
    if not producers:
        return False

    # Budget exhausted with no conclusion: assume a dependency exists.
    if search_depth == 0:
        return True

    # Otherwise, a may depend on b iff any of a's inputs (transitively) may.
    return any(may_depend_on(inp, b, search_depth - 1) for inp in producers)
73
74
def are_nodes_independent(nodes: List[Node]):
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair in nodes has a data dependency (all nodes are
        pairwise independent), False if any pair may depend on each other.
    """
    # For each unordered pair, a dependency in either direction means the
    # nodes are not independent. may_depend_on is conservative, so this
    # may report False for pairs that are actually independent.
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True
91
92
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same right-hand
    side operand into one large matrix multiplication.
               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
    M| A  |  T| B  | * K|    C    | =    |---------|
      ---- ,  |    |    |         |     T|  B * C  |
       K       ----      ---------        ---------
                K            R                R

    Arguments:
        in_mod: The module to transform. It is symbolically traced; the
            original module is not mutated.

    Returns:
        A new torch.fx.GraphModule with eligible matmuls merged.
    """
    gm = symbolic_trace(in_mod)

    # NOTE(review): lhs_users is populated below but never consumed in this
    # function — only rhs_users drives the merging. Possibly left over from
    # a planned LHS-sharing merge; confirm before removing.
    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
    # the matmul of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible (concatenating the LHS
        # operands requires all of them to be computable at the same point).
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        # A str here is an attribute name recorded above; materialize it
        # back into a get_attr node so it can be used as a graph operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
        # mms and merge_mm_res are index-aligned: mms[i]'s LHS is the i-th
        # chunk of the concatenated input, so it maps to merge_mm_res[i].
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        # All of the new nodes created above were inserted at the end, so we need to sort
        # the nodes topologically to make sure all definitions precede uses.
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm
173