# mypy: allow-untyped-defs
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum

def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Works only for length-2 patterns consisting of two module types.
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True

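# Example usage (a minimal sketch): given `traced = fx.symbolic_trace(model)`
# and `modules = dict(traced.named_modules())`, a batch-norm node `bn_node`
# whose input comes from a conv node matches via:
#
#     matches_module_pattern((nn.Conv2d, nn.BatchNorm2d), bn_node, modules)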

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Deep-copies the model
    by default; pass `inplace=True` to modify the model in place.
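
    Example usage (a minimal sketch; any eval-mode conv/BN model works):

        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
        fused_model = fuse(model)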
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    # Trace unless tracing was disabled and the model is already a GraphModule.
    if not no_trace or not isinstance(model, torch.fx.GraphModule):
        fx_model = fx.symbolic_trace(model)
    else:
        fx_model = model
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)

def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.
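
    Example usage (a minimal sketch):

        model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
        model = remove_dropout(model)  # the Dropout call is bypassed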
    """
    fx_model = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            if isinstance(self.submodules[target], nn.Dropout):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)
    return DropoutRemover(fx_model).transform()

def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Given lists of nodes from an existing graph that represent a subgraph,
    returns a submodule that executes that subgraph.
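
    Example usage (a minimal sketch; assumes `model` is a purely sequential
    chain of module calls, so every node's inputs are already in the subgraph):

        traced = fx.symbolic_trace(model)
        body = [n for n in traced.graph.nodes if n.op == 'call_module']
        sub = extract_subgraph(traced, body, [body[0].args[0]], [body[-1]])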
    """
    new_graph = fx.Graph()
    env: Dict[fx.Node, fx.Node] = {}
    for input_node in inputs:
        new_node = new_graph.placeholder(input_node.name)
        env[input_node] = new_node
    for node in nodes:
        new_node = new_graph.node_copy(node, lambda x: env[x])
        env[node] = new_node
    new_graph.output([env[output] for output in outputs])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)

mkldnn_supported = [
    nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
    torch.relu, torch.transpose, torch.sigmoid,
    F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
}


def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
    """
    For each node, if it's a module that can be preconverted into MKLDNN,
    then we do so and create a mapping to allow us to convert from the MKLDNN
    version of the module to the original.
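
    Example usage (a minimal sketch; pairs with `reset_modules` below):

        modules = dict(fx_model.named_modules())
        old_modules = modules_to_mkldnn(list(fx_model.graph.nodes), modules)
        # ... run the model with MKLDNN-layout weights ...
        reset_modules(list(fx_model.graph.nodes), modules, old_modules)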
    """
    old_modules: Dict[nn.Module, nn.Module] = {}
    for node in nodes:
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_map:
                new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
                assert isinstance(new_module, nn.Module)
                old_modules[new_module] = copy.deepcopy(cur_module)
                replace_node_module(node, modules, new_module)
    return old_modules

def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
    """
    Maps each module that's been changed with `modules_to_mkldnn` back to its
    original.
    """
    for node in nodes:
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if cur_module in old_modules:
                replace_node_module(node, modules, old_modules[cur_module])

class MklSubgraph:
    def __init__(self, fx_graph: fx.Graph):
        self.fx_graph = fx_graph
        self.nodes: List[fx.Node] = []
        self.start_nodes: List[fx.Node] = []
        self.end_nodes: List[fx.Node] = []

def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
    """
    Generates a heuristic that can be passed into `optimize_for_inference`.
    The heuristic decides whether a subgraph should be run in MKL by
    benchmarking it with `example_inputs`.

    Example usage:
        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimization.optimize_for_inference(
            model, {'mkldnn_layout_optimize': {'heuristic': heuristic}})
    """
    fx_model = None
    old_modules = None

    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            fx_model = graph.fx_graph.owning_module
            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
            ShapeProp(fx_model).propagate(example_inputs)
        sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

        def benchmark(f):
            for _ in range(warmup):
                f()
            begin = time.time()
            for _ in range(iters):
                f()
            return time.time() - begin

        mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])

        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time
    return use_mkl_heuristic

def use_mkl_length(graph: MklSubgraph) -> bool:
    """
    A heuristic that can be passed into `optimize_for_inference`; it decides
    whether a subgraph should be run in MKL by checking whether the subgraph
    has more than 2 nodes.
    """
    return len(graph.nodes) > 2

class UnionFind:
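    """
    Disjoint-set data structure with path compression and union by size,
    used below to merge MKLDNN subgraph "colors" that turn out to overlap.

    Example usage (a minimal sketch):

        uf = UnionFind(3)
        for i in range(3):
            uf.make_set(i)
        uf.join(0, 1)
        assert uf.find(0) == uf.find(1)
        assert uf.find(2) == 2
    """
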
    def __init__(self, n):
        self.parent: List[Optional[int]] = [None] * n
        self.size: List[int] = [0] * n

    def make_set(self, v: int):
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        par = self.parent[v]
        if v == par:
            return v
        assert par is not None
        self.parent[v] = self.find(par)
        return cast(int, self.parent[v])

    def join(self, a: int, b: int):
        a, b = self.find(a), self.find(b)
        if a == b:
            return a
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a
        self.size[a] += self.size[b]
        return a

def optimize_for_inference(
    model: torch.nn.Module,
    pass_config: Optional[Dict[str, Any]] = None,
    tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicitly run in MKL layout.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
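
    Example usage (a minimal sketch; the keys mirror `default_pass_config`
    below, and `gen_mkl_autotuner` supplies the heuristic):

        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimize_for_inference(
            model, {"mkldnn_layout_optimize": {"heuristic": heuristic}})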
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
    }
    if pass_config is None:
        pass_config = {}
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
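    # For example (a sketch), a supported call `x -> conv -> y` becomes
    #   x -> to_mkldnn -> conv -> to_dense -> y
    # and a cleanup pass below removes redundant to_dense -> to_mkldnn pairs
    # between adjacent MKLDNN nodes.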
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
                    assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn != MklSupport.NO:
            if supports_mkldnn == MklSupport.UNKNOWN:
                # Guard against non-Node args (e.g. scalars for add/mul).
                if not any(isinstance(arg, fx.Node) and arg.target == 'to_dense' for arg in node.args):
                    continue
            with fx_graph.inserting_before(node):
                mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n,)))

            node.args = cast(Tuple[fx.node.Argument, ...], mkldnn_args)

            with fx_graph.inserting_after(node):
                dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
                node.replace_all_uses_with(dense_x)
                # replace_all_uses_with also rewrote dense_x's own input;
                # point it back at the original node.
                dense_x.args = (node,)

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # Optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    #
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e. `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these two colors into one. That's
    # done using a Disjoint Set Union.
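    # For example (a sketch): in
    #   to_mkldnn -> conv1 \
    #                       add -> to_dense
    #   to_mkldnn -> conv2 /
    # each branch floods with its own color, and `add` joins the two colors,
    # so the whole region is treated as a single subgraph.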
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert get_color(node.args[0]) is not None
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [get_color(i) for i in node.all_input_nodes
                          if isinstance(i, fx.Node) and get_color(i) is not None]

            if len(cur_colors) == 0:
                continue
            assert not any(i is None for i in cur_colors)
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = 0
    for node in fx_graph.nodes:
        if node.target == 'to_mkldnn' or node.target == 'to_dense':
            mkldnn_conversions += 1

    logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result
410