# mypy: allow-untyped-defs
import itertools
import operator
from typing import Dict, List, Tuple

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph


def split_result_tensors(
    result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    """
    A free function for use in the merge_matmul graph transformation below that
    splits the output from a merged matmul into the individual results for each
    input tensor.

    Arguments:
        result: The merged matmul result tensor.
        inputs: The list of inputs that were merged into one for the matmul.

    Returns:
        Tuple of matmul results, one per input tensor.
    """
    # When the fx tracer is running, x.shape[0] is a torch.fx.Attribute proxy,
    # but torch.split needs concrete ints even while tracing, so fall back to
    # placeholder sizes in that case.
    if isinstance(result, torch.fx.Proxy):
        splits = [0] * len(inputs)
    else:
        splits = [x.shape[0] for x in inputs]

    return torch.split(result, splits)


def may_depend_on(a: Node, b: Node, search_depth: int = 6) -> bool:
    """
    Determine if one node depends on another in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: In the case of an indirect dependency, this function
                        searches up to this many nodes away in search of a
                        data dependency. If none is found, the function
                        makes the conservative assumption that there is a
                        dependency.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # Equivalence is defined as dependence.
    if a == b:
        return True

    # If a has no inputs, it cannot depend on b.
    if len(a.all_input_nodes) == 0:
        return False

    # If the search depth has been exhausted and no conclusion has been
    # reached, assume that there is a data dependency.
    if search_depth == 0:
        return True

    # Recursively check all inputs of a.
    return any(
        may_depend_on(inp, b, search_depth - 1) for inp in a.all_input_nodes
    )


def are_nodes_independent(nodes: List[Node]) -> bool:
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair of nodes has a data dependency on the other;
        False as soon as any pair may depend on each other.
    """
    # For each pair in nodes:
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True


def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that
    share the same right-hand side operand into one large matrix multiplication.

               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
     M| A |  T|  B | * K|    C    |  =   |---------|
      ----    |    |    |         |     T|  B * C  |
       K      ----      ---------        ---------
                K            R                R

    Returns a new torch.fx.GraphModule with the merged graph; in_mod itself
    is not mutated (it is symbolically traced into a fresh GraphModule).
    """
    gm = symbolic_trace(in_mod)

    # Map from each RHS matmul operand to the matmul nodes that use it as RHS.
    # TODO: Properly handle aliasing caused by get_attr. For now, use the
    # attribute name as the operand if the node is a get_attr.
    rhs_users: Dict[Node, List[Node]] = {}

    for node in gm.graph.nodes:
        # Only torch.matmul call_function nodes are candidates for merging.
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        _, rhs = node.args
        rhs = rhs.target if rhs.op == "get_attr" else rhs
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand,
        # materializing get_attr nodes for any operands recorded by name.
        lhs_ops = [
            gm.graph.get_attr(val) if isinstance(val, str) else val
            for val in lhs_vals
        ]
        rhs_op = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs_ops,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will
        # produce the same results as all the individual matmuls involving rhs
        # in the original graph, but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs_op), {})

        # Split the result of the merged matmul using the shapes of the LHS
        # operands to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs_ops), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs_ops))
        ]

        # Replace all uses of the original, unmerged matmuls with the
        # equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        # All of the new nodes created above were inserted at the end, so we
        # need to sort the nodes topologically to make sure all definitions
        # precede uses.
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm