1# mypy: allow-untyped-defs 2import torch.fx as fx 3from torch.fx.node import Argument, Target 4from torch.nn.utils.fusion import fuse_conv_bn_eval 5from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast 6import torch 7import torch.nn as nn 8import torch.nn.functional as F 9from torch.fx.passes.shape_prop import ShapeProp 10import copy 11from collections import defaultdict 12import torch.utils.mkldnn as th_mkldnn 13import operator 14import time 15import logging 16from enum import Enum 17 18def _parent_name(target : str) -> Tuple[str, str]: 19 """ 20 Splits a qualname into parent path and last atom. 21 For example, `foo.bar.baz` -> (`foo.bar`, `baz`) 22 """ 23 *parent, name = target.rsplit('.', 1) 24 return parent[0] if parent else '', name 25 26# Works for length 2 patterns with 2 modules 27def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): 28 if len(node.args) == 0: 29 return False 30 nodes: Tuple[Any, fx.Node] = (node.args[0], node) 31 for expected_type, current_node in zip(pattern, nodes): 32 if not isinstance(current_node, fx.Node): 33 return False 34 if current_node.op != 'call_module': 35 return False 36 if not isinstance(current_node.target, str): 37 return False 38 if current_node.target not in modules: 39 return False 40 if type(modules[current_node.target]) is not expected_type: 41 return False 42 return True 43 44 45def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): 46 assert isinstance(node.target, str) 47 parent_name, name = _parent_name(node.target) 48 modules[node.target] = new_module 49 setattr(modules[parent_name], name, new_module) 50 51def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: 52 """ 53 Fuses convolution/BN layers for inference purposes. Will deepcopy your 54 model by default, but can modify the model inplace as well. 55 """ 56 patterns = [(nn.Conv1d, nn.BatchNorm1d), 57 (nn.Conv2d, nn.BatchNorm2d), 58 (nn.Conv3d, nn.BatchNorm3d)] 59 if not inplace: 60 model = copy.deepcopy(model) 61 if not no_trace or not isinstance(model, torch.fx.GraphModule): 62 fx_model = fx.symbolic_trace(model) 63 else: 64 fx_model = model 65 modules = dict(fx_model.named_modules()) 66 new_graph = copy.deepcopy(fx_model.graph) 67 68 for pattern in patterns: 69 for node in new_graph.nodes: 70 if matches_module_pattern(pattern, node, modules): 71 if len(node.args[0].users) > 1: # Output of conv is used by other nodes 72 continue 73 conv = modules[node.args[0].target] 74 bn = modules[node.target] 75 if not bn.track_running_stats: 76 continue 77 fused_conv = fuse_conv_bn_eval(conv, bn) 78 replace_node_module(node.args[0], modules, fused_conv) 79 node.replace_all_uses_with(node.args[0]) 80 new_graph.erase_node(node) 81 return fx.GraphModule(fx_model, new_graph) 82 83def remove_dropout(model: nn.Module) -> nn.Module: 84 """ 85 Removes all dropout layers from the module. 86 """ 87 fx_model = fx.symbolic_trace(model) 88 89 class DropoutRemover(torch.fx.Transformer): 90 def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 91 if isinstance(self.submodules[target], nn.Dropout): 92 assert len(args) == 1 93 return args[0] 94 else: 95 return super().call_module(target, args, kwargs) 96 return DropoutRemover(fx_model).transform() 97 98def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): 99 """ 100 Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. 101 """ 102 new_graph = fx.Graph() 103 env: Dict[fx.Node, fx.Node] = {} 104 for input in inputs: 105 new_node = new_graph.placeholder(input.name) 106 env[input] = new_node 107 for node in nodes: 108 new_node = new_graph.node_copy(node, lambda x: env[x]) 109 env[node] = new_node 110 new_graph.output([env[output] for output in outputs]) 111 new_graph.lint() 112 return fx.GraphModule(orig_module, new_graph) 113 114mkldnn_supported = [ 115 nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, 116 torch.relu, torch.transpose, torch.sigmoid, 117 F.relu, F.avg_pool2d, F.adaptive_avg_pool2d 118] 119# These are operators that may not be convertible into MKLDNN ops (e.g. the 120# args are scalar values). Thus, we only include them in the subgraph if their 121# arguments are already in MKLDNN. 122# TODO: Determine whether this can be removed after type inference. 123mkldnn_supported_unknown = [operator.add, operator.mul] 124mkldnn_map = { 125 nn.Conv2d: th_mkldnn.MkldnnConv2d, 126 nn.Linear: th_mkldnn.MkldnnLinear, 127 nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) 128} 129 130 131def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): 132 """ 133 For each node, if it's a module that can be preconverted into MKLDNN, 134 then we do so and create a mapping to allow us to convert from the MKLDNN 135 version of the module to the original. 136 """ 137 old_modules: Dict[nn.Module, nn.Module] = {} 138 for node in nodes: 139 if node.op == 'call_module': 140 assert isinstance(node.target, str) 141 cur_module = modules[node.target] 142 if type(cur_module) in mkldnn_map: 143 new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) 144 assert isinstance(new_module, nn.Module) 145 old_modules[new_module] = copy.deepcopy(cur_module) 146 replace_node_module(node, modules, new_module) 147 return old_modules 148 149def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): 150 """ 151 Maps each module that's been changed with `modules_to_mkldnn` back to its 152 original. 153 """ 154 for node in nodes: 155 if node.op == 'call_module': 156 assert (isinstance(node.target, str)) 157 cur_module = modules[node.target] 158 if cur_module in old_modules: 159 replace_node_module(node, modules, old_modules[cur_module]) 160 161class MklSubgraph: 162 def __init__(self, fx_graph: fx.Graph): 163 self.fx_graph = fx_graph 164 self.nodes: List[fx.Node] = [] 165 self.start_nodes: List[fx.Node] = [] 166 self.end_nodes: List[fx.Node] = [] 167 168def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): 169 """ 170 This generates a heuristic that can be passed into `optimize_for_inference` that 171 determines whether a subgraph should be run in MKL by running it with the example_inputs. 172 173 Example usage: 174 heuristic = gen_mkl_autotuner(example_inputs, iters=10) 175 fast_model = optimization.optimize_for_inference(model, heuristic) 176 """ 177 fx_model = None 178 old_modules = None 179 180 def use_mkl_heuristic(graph: MklSubgraph) -> bool: 181 nonlocal fx_model, old_modules 182 input_nodes = graph.start_nodes 183 if fx_model is None: 184 fx_model = graph.fx_graph.owning_module 185 old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] 186 ShapeProp(fx_model).propagate(example_inputs) 187 sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] 188 output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) 189 submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) 190 191 def benchmark(f): 192 for _ in range(warmup): 193 f() 194 begin = time.time() 195 for _ in range(iters): 196 out = f() 197 return time.time() - begin 198 199 mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) 200 201 reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) 202 no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) 203 return mkl_time < no_mkl_time 204 return use_mkl_heuristic 205 206def use_mkl_length(graph: MklSubgraph) -> bool: 207 """ 208 This is a heuristic that can be passed into `optimize_for_inference` that 209 determines whether a subgraph should be run in MKL by checking if there 210 are more than 2 nodes in it 211 """ 212 return len(graph.nodes) > 2 213 214class UnionFind: 215 def __init__(self, n): 216 self.parent: List[Optional[int]] = [None] * n 217 self.size: List[int] = [0] * n 218 219 def make_set(self, v: int): 220 self.parent[v] = v 221 self.size[v] = 1 222 223 def find(self, v: int) -> int: 224 par = self.parent[v] 225 if v == par: 226 return v 227 assert par is not None 228 self.parent[v] = self.find(par) 229 return cast(int, self.parent[v]) 230 231 def join(self, a: int, b: int): 232 a, b = self.find(a), self.find(b) 233 if a == b: 234 return a 235 if self.size[a] < self.size[b]: 236 a, b = b, a 237 self.parent[b] = a 238 self.size[a] += self.size[b] 239 240def optimize_for_inference( 241 model: torch.nn.Module, 242 pass_config: Optional[Dict[str, Any]] = None, 243 tracer: Type[fx.Tracer] = fx.Tracer 244) -> torch.nn.Module: 245 """ 246 Performs a set of optimization passes to optimize a model for the 247 purposes of inference. Specifically, the passes that are run are: 248 1. Conv/BN fusion 249 2. Dropout removal 250 3. MKL layout optimizations 251 252 The third optimization takes a function `use_mkl_heuristic` that's used 253 to determine whether a subgraph should be explicitly run in MKL layout. 254 255 Note: As FX does not currently handle aliasing, this pass currently 256 assumes nothing aliases. If that isn't true, use at your own risk. 257 """ 258 default_pass_config = { 259 "conv_bn_fuse": True, 260 "remove_dropout": True, 261 "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, 262 } 263 if pass_config is None: 264 pass_config = {} 265 default_pass_config.update(pass_config) 266 267 if default_pass_config["conv_bn_fuse"]: 268 model = fuse(model) 269 if default_pass_config["remove_dropout"]: 270 model = remove_dropout(model) 271 if default_pass_config["mkldnn_layout_optimize"] is False: 272 return model 273 if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): 274 raise RuntimeError("mkldnn_layout_optimize config is not a dict") 275 if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: 276 raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") 277 use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] 278 279 cur_tracer = tracer() 280 fx_graph = cur_tracer.trace(copy.deepcopy(model)) 281 fx_model = fx.GraphModule(cur_tracer.root, fx_graph) 282 modules: Dict[str, nn.Module] = dict(model.named_modules()) 283 284 class MklSupport(Enum): 285 NO = 1 286 YES = 2 287 UNKNOWN = 3 288 289 # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. 290 # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. 291 # However, if it's in `mkldnn_supported_unknown`, then we only treat it as 292 # a MKLDNN node if its inputs are MKLDNN nodes. 293 for node in list(fx_graph.nodes): 294 supports_mkldnn = MklSupport.NO 295 if node.op == 'call_module': 296 cur_module = modules[node.target] 297 if type(cur_module) in mkldnn_supported: 298 supports_mkldnn = MklSupport.YES 299 sample_parameter = next(cur_module.parameters(), None) 300 if sample_parameter is not None: 301 assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules" 302 assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules" 303 elif node.op == 'call_function': 304 if node.target in mkldnn_supported: 305 supports_mkldnn = MklSupport.YES 306 elif node.target in mkldnn_supported_unknown: 307 supports_mkldnn = MklSupport.UNKNOWN 308 309 if supports_mkldnn != MklSupport.NO: 310 if supports_mkldnn == MklSupport.UNKNOWN: 311 if not any(arg.target == 'to_dense' for arg in node.args): 312 continue 313 with fx_graph.inserting_before(node): 314 mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) 315 316 node.args = cast(Tuple[fx.node.Argument], mkldnn_args) 317 318 with fx_graph.inserting_after(node): 319 dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) 320 node.replace_all_uses_with(dense_x) 321 dense_x.args = (node,) 322 323 # Does pre-conversion of all modules into MKLDNN (when possible) 324 old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) 325 fx_graph.old_modules = old_modules # type: ignore[attr-defined] 326 327 # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b 328 for node in fx_graph.nodes: 329 if node.op == 'call_method' and node.target == 'to_dense': 330 prv_node = node.args[0] 331 users = list(node.users) 332 for user in users: 333 if user.op == 'call_method' and user.target == 'to_mkldnn': 334 user.replace_all_uses_with(prv_node) 335 fx_graph.erase_node(user) 336 if len(node.users) == 0: 337 fx_graph.erase_node(node) 338 339 340 num_nodes = len(fx_graph.nodes) 341 uf = UnionFind(num_nodes) 342 343 def get_color(n): 344 if hasattr(n, 'color'): # Current node is part of a MKL subgraph 345 return uf.find(n.color) 346 if hasattr(n, 'start_color'): # Current node is input to MKL subgraph 347 return uf.find(n.start_color) 348 return None 349 350 351 # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists 352 # of input nodes (which are only `to_mkldnn` calls), output nodes 353 # (`to_dense` calls), and intermediate nodes, which are run entirely on 354 # MKLDNN layout tensors. 355 # 356 # Specifically, this code does a flood fill on a directed acyclic graph 357 # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). 358 # If every node only had one input, this would be sufficient. However, in 359 # the case that a node has multiple inputs coming from different start 360 # nodes (i.e. colors), we need to join these 2 colors into 1. That's done 361 # using a Disjoint Set Union. 362 for cur_idx, node in enumerate(fx_graph.nodes): 363 if node.op == 'call_method' and node.target == 'to_mkldnn': 364 node.start_color = cur_idx 365 uf.make_set(cur_idx) 366 elif node.op == 'call_method' and node.target == 'to_dense': 367 assert get_color(node.args[0]) is not None 368 node.end_color = get_color(node.args[0]) 369 else: 370 cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] 371 372 if len(cur_colors) == 0: 373 continue 374 assert not any(i is None for i in cur_colors) 375 cur_colors = sorted(cur_colors) 376 node.color = cur_colors[0] 377 for other_color in cur_colors[1:]: 378 uf.join(cur_colors[0], other_color) 379 380 381 mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) 382 for node in fx_graph.nodes: 383 if hasattr(node, 'color'): 384 mkldnn_graphs[uf.find(node.color)].nodes.append(node) 385 if hasattr(node, 'start_color'): 386 mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) 387 if hasattr(node, 'end_color'): 388 mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) 389 390 391 # Now that we have all the subgraphs, we need to decide which MKLDNN 392 # subgraphs we actually want to keep in MKLDNN. 393 for graph in mkldnn_graphs.values(): 394 if not use_mkl_heuristic(graph): 395 for node in graph.start_nodes + graph.end_nodes: 396 prv = node.args[0] 397 node.replace_all_uses_with(prv) 398 fx_graph.erase_node(node) 399 reset_modules(graph.nodes, modules, old_modules) 400 401 mkldnn_conversions = 0 402 for node in fx_graph.nodes: 403 if node.target == 'to_mkldnn' or node.target == 'to_dense': 404 mkldnn_conversions += 1 405 406 logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) 407 fx_graph.lint() 408 result = fx.GraphModule(model, fx_graph) 409 return result 410