xref: /aosp_15_r20/external/pytorch/torch/_dynamo/backends/registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import functools
4import logging
5import sys
6from importlib.metadata import EntryPoint
7from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple
8
9import torch
10from torch import fx
11
12
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
14
15
class CompiledFn(Protocol):
    """Structural type for the callable that a backend compiler produces.

    Any object that can be called with positional ``torch.Tensor`` arguments
    and returns a tuple of tensors satisfies this protocol.
    """

    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Run the compiled graph on ``args`` and return its output tensors."""
19
20
# Signature every backend compiler must satisfy: it receives the captured FX
# graph plus a list of (fake) input tensors and returns a CompiledFn.
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# Every known backend name. Values are the importlib.metadata EntryPoint for
# backends discovered from package metadata (loaded lazily on first lookup),
# or None as a placeholder written by register_backend.
_BACKENDS: Dict[str, Optional[EntryPoint]] = dict()
# Loaded compiler callables, keyed by backend name.
_COMPILER_FNS: Dict[str, CompilerFn] = dict()
25
26
def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable bare (``@register_backend``) or with keyword arguments
    (``@register_backend(name="my_backend", tags=("debug",))``).

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with

    Returns:
        ``compiler_fn`` itself, with a ``_tags`` tuple attached. When
        ``compiler_fn`` is omitted, returns a decorator (a
        ``functools.partial`` of this function) that performs the
        registration instead.

    Raises:
        AssertionError: if ``compiler_fn`` is not callable or a compiler is
            already registered under ``name``.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _COMPILER_FNS, f"duplicate name: {name}"
    # _BACKENDS is keyed by backend *name*. Only add a placeholder when the
    # name is new, so an EntryPoint recorded by entry-point discovery is not
    # clobbered. (Testing membership of ``compiler_fn`` here would compare a
    # callable against string keys and raise TypeError if it is unhashable.)
    if name not in _BACKENDS:
        _BACKENDS[name] = None
    _COMPILER_FNS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)
    return compiler_fn
54
55
# Shorthand decorators that register a backend pre-tagged as "debug" or
# "experimental"; list_backends() hides both of these tags by default.
register_debug_backend = functools.partial(
    register_backend,
    tags=("debug",),
)
register_experimental_backend = functools.partial(
    register_backend,
    tags=("experimental",),
)
60
61
def lookup_backend(compiler_fn):
    """Expand backend strings to functions.

    Args:
        compiler_fn: a backend name, or an already-callable compiler. Any
            non-string value is returned unchanged.

    Returns:
        The compiler callable registered under the given name, or the input
        itself when it is not a string.

    Raises:
        InvalidBackend: if ``compiler_fn`` is a string naming no known
            backend, even after the built-in backend modules are imported and
            entry-point backends are discovered.
    """
    if not isinstance(compiler_fn, str):
        return compiler_fn

    name = compiler_fn
    if name not in _BACKENDS:
        # Populate the registry on the first miss; _lazy_import is cached, so
        # it only does real work once.
        _lazy_import()
    if name not in _BACKENDS:
        from ..exc import InvalidBackend

        raise InvalidBackend(name=name)

    if name not in _COMPILER_FNS:
        # Names known to _BACKENDS but without a loaded compiler come from
        # entry-point discovery; load the compiler on first use.
        loaded_fn = _BACKENDS[name].load()
        # Importing the entry point's module may itself have registered the
        # backend under this name (e.g. via @register_backend). Registering
        # it a second time would trip register_backend's duplicate-name
        # assertion, so only register if that did not happen.
        if name not in _COMPILER_FNS:
            register_backend(compiler_fn=loaded_fn, name=name)
    return _COMPILER_FNS[name]
77
78
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")

    Args:
        exclude_tags: backends carrying any of these tags are omitted. Pass
            ``None`` or an empty sequence to list every backend.

    Returns:
        Backend names in sorted order. Backends whose compiler has not been
        loaded yet have no tags to inspect, so they are always included.
    """
    _lazy_import()
    hidden_tags = set(exclude_tags or ())

    visible = []
    for backend_name in _BACKENDS:
        if backend_name in _COMPILER_FNS:
            backend_tags = _COMPILER_FNS[backend_name]._tags
            if hidden_tags.intersection(backend_tags):
                continue
        visible.append(backend_name)
    visible.sort()
    return visible
95
96
@functools.lru_cache(None)
def _lazy_import():
    """Populate the backend registry by importing the backend modules.

    This is a zero-argument function wrapped in ``lru_cache``. After the
    first call that returns normally, later calls return the cached result
    without re-running the body.
    """
    # Import the backends package and (presumably) all of its submodules via
    # import_submodule; the submodules are assumed to register their backends
    # at import time, e.g. with @register_backend. TODO confirm.
    from .. import backends
    from ..utils import import_submodule

    import_submodule(backends)

    # Imported for its side effect: presumably importing this module registers
    # the minifier backend. The assert keeps the imported name "used".
    from ..repro.after_dynamo import dynamo_minifier_backend

    assert dynamo_minifier_backend is not None

    # Finally, record third-party backends advertised via package entry points.
    _discover_entrypoint_backends()
108    _discover_entrypoint_backends()
109
110
@functools.lru_cache(None)
def _discover_entrypoint_backends():
    """Record backends advertised in the ``torch_dynamo_backends`` entry-point group.

    Each EntryPoint is stored in ``_BACKENDS`` under its entry-point name
    without being loaded; loading is deferred until the backend is looked up.
    """
    # importing here so it will pick up the mocked version in test_backends.py
    from importlib.metadata import entry_points

    group_name = "torch_dynamo_backends"
    if sys.version_info >= (3, 10):
        selected = entry_points(group=group_name)
        discovered = {ep_name: selected[ep_name] for ep_name in selected.names}
    else:
        # Before 3.10, entry_points() takes no arguments and returns a mapping
        # from group name to the entry points in that group.
        all_groups = entry_points()
        group_eps = all_groups[group_name] if group_name in all_groups else []
        discovered = {ep.name: ep for ep in group_eps}
    _BACKENDS.update(discovered)
126