# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import operator

import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility

__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']

Tensors = Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}


@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    if isinstance(k, str):
        return k
    elif k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    else:
        # WAR for bug in how torch.ops assigns module
        module = k.__module__.replace('torch._ops', 'torch.ops') if k.__module__ else ''
        return f"{module}.{k.__name__}"


@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Given a `node`, returns its target typename.

    For a "call_method" node, returns node.target, which is the name of the method
    being called. This could potentially lead to name conflicts, but is normally
    fine because the method is usually invoked on a tensor.

    For a "call_function" node, returns the typename of node.target.

    For a "call_module" node, returns the typename of the module that node.target
    points to.

    If "_VariableFunctionsClass" appears in the target name string, it is replaced
    by "torch", e.g. _VariableFunctionsClass.relu becomes torch.relu.
    """

    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        submod = submodules[node.target]
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        return get_acc_ops_name(submod_type)
    elif node.op == "call_function":
        target: Any = node.target
        return (
            f"acc_ops.{target.__name__}"
            if target.__module__ is not None and "acc_ops" in target.__module__
            else _get_qualified_name(target)
        )
    else:
        assert isinstance(node.target, str)
        return node.target


@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks whether the node output produces a Tensor.

    NOTE: This requires running `ShapeProp` on the containing fx graph before
    calling this function. It works by checking the `type` metadata on the
    node, which is produced by `ShapeProp`.
    """
    type_ = node.meta.get("type", None)
    return type_ is not None and issubclass(type_, torch.Tensor)
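

# Usage sketch (illustrative, not part of the original module): `get_node_target`
# and `is_node_output_tensor` are typically applied to the nodes of a traced
# GraphModule after `ShapeProp` has populated `node.meta["type"]`. The toy module
# and the helper name `_example_inspect_nodes` below are hypothetical.
def _example_inspect_nodes() -> None:
    from torch.fx.passes.shape_prop import ShapeProp

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(Toy())
    # ShapeProp fills in `node.meta["type"]`, which `is_node_output_tensor` reads.
    ShapeProp(gm).propagate(torch.randn(2, 3))

    submodules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if node.op in CALLABLE_NODE_OPS:
            # e.g. the relu node reports "torch.relu" as its target typename.
            print(node.name, get_node_target(submodules, node), is_node_output_tensor(node))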


@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        self.nodes = list(module.graph.nodes)
        self.acc_nodes = acc_nodes

    @dataclass
    class FusionGroup:
        # The smallest index among the nodes in the fusion group, after
        # topologically sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to the fusion group.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ):
        """
        Start from `inputs` and walk the graph in reverse topological order. If any
        upstream node is in the fusion group, add all the nodes on this path to the
        fusion group.
        """
        for arg in inputs:
            # Skip the node if it has already been seen.
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr nodes because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has a smaller index, it's already an upstream node of the
            # fusion group. We don't need to check it anymore.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node. If any of them is in the fusion
            # group, add this node to the fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        result: Dict[torch.fx.Node, NodeSet] = {}
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            if not (set(fusion_group.nodes) <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
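

# Usage sketch (illustrative, not part of the original module): a typical
# invocation of `FxNetAccFusionsFinder`. In practice the ACC node set comes from
# an operator-support query inside a partitioner; here every callable node is
# treated as accelerator-supported. The helper name `_example_find_fusions` is
# hypothetical.
def _example_find_fusions(gm: torch.fx.GraphModule) -> None:
    acc_nodes: NodeSet = {n for n in gm.graph.nodes if n.op in CALLABLE_NODE_OPS}
    # The finder may shrink `acc_nodes` in place if a fusion group turns out to
    # contain non-ACC nodes; the returned dict maps each fused node to its group.
    fusions = FxNetAccFusionsFinder(gm, acc_nodes)()
    for node, group in fusions.items():
        print(node.name, "->", sorted(n.name for n in group))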


@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by transformations such as merge_matmul, which disturb the topologically
    sorted order of their input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module, with its graph topologically sorted in place.
    """

    # These operators are used for making runtime assertions before any
    # data-dependent operators occur. We want to prioritize sorting these to
    # ensure that the assertions appear before any data-dependent operations
    # in the graph.
    PRIORITIZED_OPS = [
        operator.add,
        operator.mul,
        operator.sub,
        operator.floordiv,
        operator.truediv,
        operator.mod,
        operator.le,
        operator.lt,
        operator.ge,
        operator.gt,
        operator.eq,
        operator.ne,
        torch.ops.aten.sym_constrain_range.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten._assert_async.msg,
        torch.ops.aten.scalar_tensor.default,
        torch.ops.aten._assert_scalar.default,
    ]

    indeg = dict.fromkeys(gm.graph.nodes, 0)
    new_graph = torch.fx.Graph()
    # Track how many unfulfilled dependencies each node has.
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1
    queue: collections.deque = collections.deque()
    # Add all nodes with no dependencies to the queue.
    for node in gm.graph.nodes:
        if indeg[node] == 0:
            queue.append(node)
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Pop nodes from the queue, and add nodes that have had all their
    # dependencies fulfilled.
    while len(queue) > 0:
        cur = queue.popleft()
        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                if user.op == "call_function" and user.target in PRIORITIZED_OPS:
                    queue.appendleft(user)
                else:
                    queue.append(user)
    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied).
    if len(new_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph
    return gm
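

# Usage sketch (illustrative, not part of the original module): `legalize_graph`
# restoring topological order after nodes have been moved. `Node.prepend` is used
# here only to deliberately break the ordering; the helper name
# `_example_legalize` is hypothetical.
def _example_legalize() -> torch.fx.GraphModule:
    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.neg(torch.relu(x))

    gm = torch.fx.symbolic_trace(Toy())
    relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
    neg = next(n for n in gm.graph.nodes if n.target is torch.neg)
    # Move `neg` before its producer `relu`, breaking topological order.
    relu.prepend(neg)
    legalize_graph(gm)  # sorts the graph in place and returns `gm`
    gm.recompile()  # regenerate the generated `forward` from the sorted graph
    return gm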