xref: /aosp_15_r20/external/executorch/exir/pass_manager.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, List, Optional, Union
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Workerimport torch.fx.passes.infra.pass_manager as fx
13*523fa7a6SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType
15*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.pass_base import PassResult
16*523fa7a6SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias
17*523fa7a6SAndroid Build Coastguard Worker
# A pass is any callable that takes a torch.fx.GraphModule and returns either
# a PassResult or None.
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker
class PassManager(fx.PassManager):
    """
    Class to run multiple passes on a given graph module. The PassManager is
    callable so to run it, we can just call the PassManager instance.

    This subclass of ``torch.fx.passes.infra.pass_manager.PassManager`` adds
    two things on top of the base class:

        * The constructor accepts a flat or nested list of passes, flattens
          it, and wraps every pass with ``fx.pass_result_wrapper`` before
          forwarding the result (together with the check flags) to the base
          class constructor.
        * ``check`` is overridden to validate the graph module and to reject
          ``call_method`` nodes.
    """

    def __init__(
        self,
        passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ) -> None:
        """
        Args:
            passes: A list of passes, or a list of lists of passes. ``None``
                or an empty list results in a manager with no passes.
            run_checks_after_each_pass: whether to run checks and linting
                after each pass
            suppress_check_failures: forwarded unchanged to the base class
                constructor. NOTE(review): presumably controls whether check
                failures are raised or suppressed -- confirm against the
                torch.fx PassManager documentation.
        """

        # Flatten the (possibly nested) passes to a flat list of callables and
        # wrap each one with the base module's pass_result_wrapper.
        flattened_passes = [
            fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes or [])[0]
        ]

        super().__init__(
            flattened_passes,
            run_checks_after_each_pass=run_checks_after_each_pass,
            suppress_check_failures=suppress_check_failures,
        )

    def check(self, module: torch.nn.Module) -> None:
        """
        Runs checks on the given graph module to make sure it is in a state
        that passes can work with.

        Checks performed:
            - The module is an instance of torch.fx.GraphModule.
            - The module recompiles and its graph passes ``Graph.lint()``.
            - The graph contains no ``call_method`` nodes.

        Checking that operator node types match the node's spec field is not
        done here (see the TODO below).

        Args:
            module: The graph module to validate.

        Raises:
            AssertionError: If ``module`` is not a torch.fx.GraphModule.
            ExportError: With ``ExportErrorType.NOT_SUPPORTED`` if the graph
                contains a ``call_method`` node.
        """
        # Use the public torch.fx.GraphModule name rather than relying on the
        # pass_manager module (`fx`) happening to import it.
        assert isinstance(
            module, torch.fx.GraphModule
        ), f"Expected a torch.fx.GraphModule, got {type(module).__name__}"
        module.recompile()
        module.graph.lint()
        # TODO(qihan): use verifier.check_is_exir

        for node in module.graph.nodes:
            if node.op == "call_method":
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    f"call_method `{node}` is not supported except for backend delegate.",
                )
79