xref: /aosp_15_r20/external/executorch/backends/cadence/aot/_passes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from typing import Any, List, Optional, Type
10
11import torch
12import torch.fx
13import torch.utils._pytree as pytree
14from executorch.backends.cadence.aot.fuse_ops import (
15    CadenceFuseOpsInGraph,
16    FuseFullThenReshapePass,
17    FuseTransposeOpPairsPass,
18)
19from executorch.backends.cadence.aot.pass_utils import (
20    CadencePassAttribute,
21    create_cadence_pass_filter,
22    register_cadence_pass,
23)
24
25from executorch.backends.cadence.aot.remove_ops import (
26    CadenceRemoveNops,
27    RemoveNopSliceOrViewOpPass,
28    RemoveRedundantOps,
29)
30from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
31from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
32from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
33from executorch.exir.pass_base import ExportPass, PassResult
34from executorch.exir.pass_manager import PassManager, PassType
35from executorch.exir.passes import dead_code_elimination_pass
36from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
37from executorch.exir.passes.spec_prop_pass import SpecPropPass
38
39
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class InitializePipeline(ExportPass):
    """
    Opening pass of the Cadence pass pipeline; it must always be the first
    pass that runs.

    It applies dead-code elimination to the incoming graph module and then
    runs `SpecPropPass` over it, handing back that pass's result.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run dead-code elimination followed by `SpecPropPass`.

        Args:
            graph_module: The graph module the pipeline is about to process.

        Returns:
            The non-None `PassResult` produced by `SpecPropPass`.
        """
        # The helper's return value is unused, so it presumably rewrites
        # `graph_module` in place.
        dead_code_elimination_pass(graph_module)
        spec_prop = SpecPropPass()
        prop_result = spec_prop(graph_module)
        # Guard against (and narrow away, for the type checker) a None result
        # before returning it as a concrete PassResult.
        assert prop_result is not None
        return prop_result
52
53
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FinalizePipeline(ExportPass):
    """
    Cleanup pass run after the main Cadence optimization passes.

    It runs `ScalarToTensorPass` and then `SpecPropPass` through a
    `PassManager`, and afterwards applies dead-code elimination to the graph
    module in the resulting `PassResult`.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the cleanup passes, then dead-code elimination.

        Args:
            graph_module: The graph module produced by the earlier passes.

        Returns:
            The `PassResult` returned by the `PassManager`. Dead-code
            elimination is then applied to its `graph_module` (presumably in
            place, since the helper's return value is unused).
        """
        # Explicit annotation so the list literal is accepted as
        # List[PassType] by the type checker (this file is pyre-strict).
        cleanup_passes: List[PassType] = [ScalarToTensorPass(), SpecPropPass()]
        manager = PassManager(passes=cleanup_passes)
        managed_result = manager(graph_module)
        dead_code_elimination_pass(managed_result.graph_module)
        return managed_result
68
69
# Loose type alias, similar to what's done in executorch/exir/pass_base.py.
# It is bound to `Any`, so it adds no static checking by itself.
# NOTE(review): `Argument` is not referenced in the visible part of this file;
# presumably it is kept for importers elsewhere — confirm before removing.
Argument = Any  # pyre-ignore
72
73
def get_passes_in_default_order() -> List[Type[PassType]]:
    """
    Return the Cadence pass types in the order they should be run.

    Individual pass classes and pass groups (the `.passes` attributes,
    presumably lists of pass classes) are flattened with
    `pytree.tree_flatten`, so the result is a single flat, ordered list.
    The passes are not instantiated or filtered here; `get_cadence_passes`
    does that.

    Returns:
        Flat list of pass types in execution order.
    """
    prologue = [InitializePipeline]
    optimization_groups = [
        RemoveRedundantOps.passes,
        CadenceReorderOpsInGraph.passes,
        # Phase ordering: remove -> fusion -> replacement passes.
        CadenceRemoveNops.passes,
        CadenceFuseOpsInGraph.passes,
        CadenceReplaceOpsInGraph.passes,
        CadenceSimplifyOpsInGraph.passes,
    ]
    # NOTE(review): FinalizePipeline is documented as the final cleanup pass,
    # yet the three passes below are scheduled after it. Presumably
    # deliberate — confirm before reordering.
    epilogue = [
        FinalizePipeline,
        FuseFullThenReshapePass,
        FuseTransposeOpPairsPass,
        RemoveNopSliceOrViewOpPass,
    ]
    # Nesting the three phases one level deeper does not change the leaf
    # order produced by tree_flatten.
    flat_pass_types, _ = pytree.tree_flatten(
        [prologue, optimization_groups, epilogue]
    )
    return flat_pass_types
90
91
def get_cadence_passes(
    opt_level: int,
) -> List[PassType]:
    """
    Build the Cadence pass instances enabled at the given optimization level.

    The pass types from `get_passes_in_default_order` are filtered with the
    predicate returned by `create_cadence_pass_filter(opt_level)`. Each pass
    type that survives is instantiated with no arguments, and the relative
    order is preserved.

    Args:
        opt_level: Optimization level handed to `create_cadence_pass_filter`.

    Returns:
        Instantiated passes (callables taking a `torch.fx.GraphModule`) in
        execution order. These are pass objects, not `PassResult`s.
    """
    pass_types = get_passes_in_default_order()
    pass_filter = create_cadence_pass_filter(opt_level)
    return [
        # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
        pass_type()
        # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
        for pass_type in filter(pass_filter, pass_types)
    ]
104