# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import re
import warnings
from typing import Callable, Dict, List, Optional

from executorch.exir.error import ExportError, ExportErrorType

from executorch.exir.pass_manager import PassType


class PassRegistry:
    """
    Allows passes to be automatically registered into a global registry, and
    users to search within the registry by the pass's string name to get a pass.

    Attributes:
        registry: A dictionary of names of passes mapping to a list of passes in
            the form of callable functions or PassBase instances (which are also
            callable). It is a class attribute, so every use of ``PassRegistry``
            reads and writes the same dictionary.
    """

    registry: Dict[str, List[PassType]] = {}

    @classmethod
    def register(
        cls, pass_name: Optional[str] = None
    ) -> Callable[[PassType], PassType]:
        """
        A decorator used on top of passes to insert a pass into the registry.

        If pass_name is not specified (or is empty), it is generated from the
        name of the object passed in, converted from CamelCase to snake_case
        (e.g. ``MyCustomPass`` -> ``my_custom_pass``). Objects without a
        ``__name__`` (such as pass instances) use the name of their type.

        This decorator can be used on top of functions (with type
        PassManagerParams * torch.fx.GraphModule -> None) or on top of PassBase
        subclasses instances. It must be invoked, i.e. written as
        ``@PassRegistry.register()`` or ``@PassRegistry.register("name")``.

        Args:
            pass_name: Optional explicit registry key for the decorated pass.

        Returns:
            A decorator that registers its argument as a single-pass list and
            returns the argument unchanged.
        """

        def wrapper(one_pass: PassType) -> PassType:
            key = pass_name
            if not key:
                # Functions and classes carry __name__; plain instances usually
                # do not, so fall back to the name of their type.
                name = getattr(one_pass, "__name__", None) or type(one_pass).__name__
                # Insert "_" before every uppercase letter except at the start,
                # then lowercase: "MyPass" -> "my_pass".
                key = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

            cls.register_list(key, [one_pass])
            return one_pass

        return wrapper

    @classmethod
    def register_list(cls, pass_name: str, pass_list: List[PassType]) -> None:
        """
        A function used to insert a list of passes into the registry. The pass
        can be searched for in the registry according to the given pass name.

        If ``pass_name`` is already registered, a warning is emitted and the
        existing entry is left untouched (the first registration wins).

        Args:
            pass_name: The name under which the passes are stored.
            pass_list: The passes to store. A shallow copy is stored so that
                later mutation of the caller's list does not alter the registry.
        """
        if pass_name in cls.registry:
            # stacklevel=2 attributes the warning to the caller of
            # register_list instead of to this module.
            warnings.warn(
                f"Pass {pass_name} already exists inside of the PassRegistry. Will ignore.",
                stacklevel=2,
            )
            return

        cls.registry[pass_name] = list(pass_list)

    @classmethod
    def get(cls, key: str) -> List[PassType]:
        """
        Gets the passes registered under the given name.

        Args:
            key: The name of a pass

        Returns:
            The list of callable passes registered under ``key``. A pass added
            through ``register`` is returned as a single-element list.

        Raises:
            ExportError: with type MISSING_PROPERTY if nothing is registered
                under ``key``.
        """
        try:
            return cls.registry[key]
        except KeyError:
            raise ExportError(
                ExportErrorType.MISSING_PROPERTY,
                f"Pass {key} does not exists inside of the PassRegistry",
            ) from None