xref: /aosp_15_r20/external/executorch/exir/passes/pass_registry.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import re
10import warnings
11from typing import Callable, Dict, List, Optional
12
13from executorch.exir.error import ExportError, ExportErrorType
14
15from executorch.exir.pass_manager import PassType
16
17
class PassRegistry:
    """
    Allows passes to be automatically registered into a global registry, and
    users to search within the registry by the pass's string name to get a pass.

    Attributes:
        registry: A dictionary of names of passes mapping to a list of passes in
        the form of callable functions or PassBase instances (which are also callable)
    """

    # Class-level attribute: every caller of PassRegistry shares this one mapping.
    registry: Dict[str, List[PassType]] = {}

    @classmethod
    def register(
        cls, pass_name: Optional[str] = None
    ) -> Callable[[PassType], PassType]:
        """
        A decorator used on top of passes to insert a pass into the registry.

        If ``pass_name`` is not specified (or is empty), the key is derived from
        the decorated object's name by converting CamelCase to snake_case, e.g.
        ``MyFancyPass`` -> ``my_fancy_pass``; a name without capitals such as
        ``my_pass`` is used unchanged. Objects without a ``__name__`` (e.g. a
        callable instance) fall back to the name of their class.

        This decorator can be used on top of functions (with type
        PassManagerParams * torch.fx.GraphModule -> None) or on top of PassBase
        subclasses, and can also be called directly on a PassBase instance.

        Args:
            pass_name: Optional explicit registry key for the pass.

        Returns:
            A decorator that registers its argument as a single-element pass
            list and returns the argument unchanged.
        """

        def wrapper(one_pass: PassType) -> PassType:
            key = pass_name
            if not key:
                # Functions and classes expose __name__; instances usually do
                # not, so fall back to the class name instead of raising.
                raw_name = getattr(one_pass, "__name__", None) or type(
                    one_pass
                ).__name__
                # Insert "_" before every capital letter except at the start.
                key = re.sub(r"(?<!^)(?=[A-Z])", "_", raw_name).lower()

            cls.register_list(key, [one_pass])
            return one_pass

        return wrapper

    @classmethod
    def register_list(cls, pass_name: str, pass_list: List[PassType]) -> None:
        """
        A function used to insert a list of passes into the registry. The pass
        can be searched for in the registry according to the given pass name.

        Registration is first-come-wins: if ``pass_name`` is already present, a
        warning is emitted and the existing entry is left unchanged.

        Args:
            pass_name: Name under which the passes can later be looked up.
            pass_list: The passes to register. A shallow copy is stored, so
                later mutation of the caller's list does not alter the registry.
        """

        if pass_name in cls.registry:
            warnings.warn(
                f"Pass {pass_name} already exists inside of the PassRegistry. Will ignore.",
                # Attribute the warning to the caller of register_list() rather
                # than to this line inside the registry module.
                stacklevel=2,
            )
            return

        cls.registry[pass_name] = list(pass_list)

    @classmethod
    def get(cls, key: str) -> List[PassType]:
        """
        Gets the passes registered under the given name.

        Args:
            key: The name of a pass, as passed to register()/register_list() or
                derived by register().

        Returns:
            The list of passes registered under ``key``. A single pass
            registered via register() is returned wrapped in a one-element list.

        Raises:
            ExportError: With ExportErrorType.MISSING_PROPERTY if no pass is
                registered under ``key``.
        """
        try:
            return cls.registry[key]
        except KeyError:
            raise ExportError(
                ExportErrorType.MISSING_PROPERTY,
                f"Pass {key} does not exist inside of the PassRegistry",
            ) from None
90