xref: /aosp_15_r20/external/executorch/exir/passes/pass_registry.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport re
10*523fa7a6SAndroid Build Coastguard Workerimport warnings
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, Optional
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_manager import PassType
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
class PassRegistry:
    """
    Allows passes to be automatically registered into a global registry, and
    users to search within the registry by the pass's string name to get a pass.

    Attributes:
        registry: A dictionary mapping a pass name to the list of passes
            registered under that name. Each pass is a callable function or a
            PassBase instance (which is also callable). The dictionary is
            defined on the class itself, so it is shared by every user of
            PassRegistry.
    """

    registry: Dict[str, List[PassType]] = {}

    @classmethod
    def register(
        cls, pass_name: Optional[str] = None
    ) -> Callable[[PassType], PassType]:
        """
        A decorator used on top of passes to insert a pass into the registry.

        If pass_name is not specified (None or empty), the registry key is
        derived from the name of the decorated object by converting it from
        CamelCase to snake_case (an underscore is inserted before every
        capital letter except a leading one, then the result is lowercased).
        For objects that have no `__name__` of their own, such as callable
        instances, the name of the object's class is used instead.

        This decorator can be used on top of functions (with type
        PassManagerParams * torch.fx.GraphModule -> None) or on top of PassBase
        subclasses instances.

        Args:
            pass_name: The name to register the pass under, or None to derive
                the name from the decorated object.

        Return:
            A decorator that registers the pass it receives as a one-element
            list and returns that same pass unchanged.
        """

        def wrapper(one_pass: PassType) -> PassType:
            key = pass_name
            if not key:
                # Functions and classes carry `__name__`, but callable
                # instances normally do not; fall back to the class name so
                # that registering an instance without an explicit name does
                # not raise AttributeError.
                raw_name = (
                    getattr(one_pass, "__name__", None) or type(one_pass).__name__
                )
                key = re.sub(r"(?<!^)(?=[A-Z])", "_", raw_name).lower()

            cls.register_list(key, [one_pass])
            return one_pass

        return wrapper

    @classmethod
    def register_list(cls, pass_name: str, pass_list: List[PassType]) -> None:
        """
        A function used to insert a list of passes into the registry. The
        passes can be searched for in the registry according to the given pass
        name.

        If pass_name is already registered, a warning is emitted and the
        existing entry is kept; the new list is ignored.

        Args:
            pass_name: The name to register the passes under.
            pass_list: The passes to register. The list object itself is
                stored (no copy is made).
        """

        if pass_name in cls.registry:
            warnings.warn(
                f"Pass {pass_name} already exists inside of the PassRegistry. Will ignore.",
                stacklevel=1,
            )
            return

        cls.registry[pass_name] = pass_list

    @classmethod
    def get(cls, key: str) -> List[PassType]:
        """
        Gets the list of passes registered under the given name. A pass added
        through the `register` decorator is returned as a one-element list.

        Args:
            key: The name of a pass

        Return:
            The list of callable passes stored under `key`. This is the list
            object held by the registry, not a copy.

        Raises:
            ExportError: With type ExportErrorType.MISSING_PROPERTY if no pass
                has been registered under `key`.
        """
        if key not in cls.registry:
            raise ExportError(
                ExportErrorType.MISSING_PROPERTY,
                f"Pass {key} does not exists inside of the PassRegistry",
            )

        return cls.registry[key]
90