# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import re
import warnings
from typing import Callable, Dict, List, Optional

from executorch.exir.error import ExportError, ExportErrorType

from executorch.exir.pass_manager import PassType


class PassRegistry:
    """
    Allows passes to be automatically registered into a global registry, and
    users to search within the registry by the pass’s string name to get a pass.

    Attributes:
        registry: A dictionary mapping pass names to lists of passes, in the
            form of callable functions or PassBase instances (which are also
            callable).
    """

    registry: Dict[str, List[PassType]] = {}

    @classmethod
    def register(
        cls, pass_name: Optional[str] = None
    ) -> Callable[[PassType], PassType]:
        """
        A decorator used on top of passes to insert a pass into the registry. If
        pass_name is not specified, it is derived from the ``__name__`` of the
        decorated object by inserting an underscore before every uppercase
        letter (except the first character) and lowercasing the result, e.g.
        ``ConstPropPass`` becomes ``const_prop_pass``.

        This decorator can be used on top of functions (with type
        PassManagerParams * torch.fx.GraphModule -> None) or on PassBase
        subclass instances.
        """

        def wrapper(one_pass: PassType) -> PassType:
            key = pass_name
            if not key:
                # CamelCase -> snake_case: insert "_" before each uppercase
                # letter that is not at the start of the name, then lowercase.
                key = re.sub(r"(?<!^)(?=[A-Z])", "_", one_pass.__name__).lower()

            cls.register_list(key, [one_pass])
            return one_pass

        return wrapper

    @classmethod
    def register_list(cls, pass_name: str, pass_list: List[PassType]) -> None:
        """
        Inserts a list of passes into the registry under the given pass name,
        so that the list can later be looked up by that name. If the name is
        already registered, a warning is emitted and the existing entry is kept.
        """

        if pass_name in cls.registry:
            warnings.warn(
                f"Pass {pass_name} already exists inside of the PassRegistry. Will ignore.",
                stacklevel=1,
            )
            return

        cls.registry[pass_name] = pass_list

    @classmethod
    def get(cls, key: str) -> List[PassType]:
        """
        Gets the list of passes registered under the given name.

        Args:
            key: The name of a pass

        Returns:
            The list of callable passes registered under ``key``.

        Raises:
            ExportError: If no pass is registered under ``key``.
        """
        if key not in cls.registry:
            raise ExportError(
                ExportErrorType.MISSING_PROPERTY,
                f"Pass {key} does not exist inside of the PassRegistry",
            )

        pass_found = cls.registry[key]
        return pass_found
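

# Usage sketch for PassRegistry (illustrative only: the pass names and
# functions below are hypothetical, are not part of this module, and their
# signatures are simplified):
#
#     @PassRegistry.register()
#     def ReplaceViewOps(graph_module):  # registered as "replace_view_ops"
#         ...
#
#     @PassRegistry.register("my_pass")
#     def some_pass(graph_module):  # registered as "my_pass"
#         ...
#
#     PassRegistry.register_list("my_pipeline", [ReplaceViewOps, some_pass])
#
#     PassRegistry.get("replace_view_ops")  # [ReplaceViewOps]
#     PassRegistry.get("my_pipeline")       # [ReplaceViewOps, some_pass]
#     PassRegistry.get("unknown")           # raises ExportError (MISSING_PROPERTY)
#
# Because `registry` is a class attribute, every registration is shared
# process-wide; registering an existing name again only emits a warning and
# keeps the first entry.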