# mypy: ignore-errors

"""
Registry mapping backend names to compiler functions for ``torch.compile``.

Two module-level tables are kept:

* ``_BACKENDS`` maps every known backend name to either ``None`` (the backend
  was registered directly via :func:`register_backend`) or the
  :class:`importlib.metadata.EntryPoint` it was discovered from in the
  ``torch_dynamo_backends`` group (loaded lazily on first lookup).
* ``_COMPILER_FNS`` maps names to the compiler callables that have actually
  been registered/loaded.
"""

import functools
import logging
import sys
from importlib.metadata import EntryPoint
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

import torch
from torch import fx


log = logging.getLogger(__name__)


class CompiledFn(Protocol):
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...


CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# name -> None (registered directly) or EntryPoint (discovered via metadata)
_BACKENDS: Dict[str, Optional[EntryPoint]] = {}
# name -> registered/loaded compiler callable
_COMPILER_FNS: Dict[str, CompilerFn] = {}


def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand. Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable both as ``@register_backend`` and ``@register_backend(name=...)``.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with

    Returns:
        ``compiler_fn`` itself (with a ``_tags`` tuple attribute attached),
        or, when ``compiler_fn`` is omitted, a decorator that registers its
        argument with the given ``name`` and ``tags``.

    Raises:
        AssertionError: if ``compiler_fn`` is not callable or ``name`` is
            already registered.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _COMPILER_FNS, f"duplicate name: {name}"
    # Key on the *name* (``_BACKENDS`` is keyed by strings). A backend that was
    # discovered through an entry point already has its EntryPoint recorded
    # here; that record must not be clobbered when ``lookup_backend`` loads
    # the entry point and registers the result.
    if name not in _BACKENDS:
        _BACKENDS[name] = None
    _COMPILER_FNS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)
    return compiler_fn


register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)


def lookup_backend(compiler_fn):
    """
    Expand backend strings to functions.

    Non-string arguments are returned unchanged. For a string, the built-in
    backend modules are imported and entry-point backends discovered if the
    name is not yet known; an entry-point backend is loaded and registered
    the first time it is looked up.

    Raises:
        InvalidBackend: if the string does not name any known backend.
    """
    if isinstance(compiler_fn, str):
        if compiler_fn not in _BACKENDS:
            _lazy_import()
        if compiler_fn not in _BACKENDS:
            from ..exc import InvalidBackend

            raise InvalidBackend(name=compiler_fn)

        if compiler_fn not in _COMPILER_FNS:
            # register_backend fills both tables at once, so a name present in
            # _BACKENDS but missing from _COMPILER_FNS came from entry-point
            # discovery: load it now and register the loaded callable.
            entry_point = _BACKENDS[compiler_fn]
            register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
        compiler_fn = _COMPILER_FNS[compiler_fn]
    return compiler_fn


def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")

    Args:
        exclude_tags: Backends carrying any of these tags are omitted.
            A falsy value disables filtering.

    Returns:
        Sorted list of backend names. Entry-point backends that have not
        been loaded yet have no tags available, so they are always included.
    """
    _lazy_import()
    exclude_tags = set(exclude_tags or ())

    backends = [
        name
        for name in _BACKENDS
        if name not in _COMPILER_FNS
        or not exclude_tags.intersection(_COMPILER_FNS[name]._tags)
    ]
    return sorted(backends)


@functools.lru_cache(None)
def _lazy_import():
    """
    Import all submodules of the ``backends`` package and discover
    entry-point backends. Cached, so the work happens at most once.
    """
    from .. import backends
    from ..utils import import_submodule

    # NOTE(review): presumably importing the submodules runs their
    # ``register_backend`` decorators — confirm in the backends package.
    import_submodule(backends)

    from ..repro.after_dynamo import dynamo_minifier_backend

    # The assert keeps the import above from being treated as unused; the
    # import is assumed to register the minifier backend as a side effect.
    assert dynamo_minifier_backend is not None

    _discover_entrypoint_backends()


@functools.lru_cache(None)
def _discover_entrypoint_backends():
    """
    Record every entry point in the ``torch_dynamo_backends`` group in
    ``_BACKENDS`` (name -> EntryPoint) without loading it. Cached, so it
    runs at most once.
    """
    # importing here so it will pick up the mocked version in test_backends.py
    from importlib.metadata import entry_points

    group_name = "torch_dynamo_backends"
    if sys.version_info < (3, 10):
        # Before 3.10, entry_points() returns a dict of group -> entry points.
        eps = entry_points()
        eps = eps[group_name] if group_name in eps else []
        eps = {ep.name: ep for ep in eps}
    else:
        eps = entry_points(group=group_name)
        eps = {name: eps[name] for name in eps.names}
    _BACKENDS.update(eps)