xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/match.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from .core import unify, reify  # type: ignore[attr-defined]
3from .variable import isvar
4from .utils import _toposort, freeze
5from .unification_tools import groupby, first  # type: ignore[import]
6
7
class Dispatcher:
    """Dispatch calls to implementations selected by pattern unification.

    Implementations are registered with ``add`` (or the ``register``
    decorator) under a signature: a sequence of patterns that may contain
    logic variables.  Calling the dispatcher runs ``resolve`` and invokes the
    first implementation whose signature has the same length as the
    positional arguments and unifies with them.
    """

    def __init__(self, name):
        self.name = name
        # Maps frozen signature -> implementation.
        self.funcs = {}
        # Signatures in the order ``resolve`` tries them; rebuilt by ``add``.
        self.ordering = []

    def add(self, signature, func):
        """Register ``func`` under ``signature`` and recompute the try order.

        The signature is stored in frozen form (via ``freeze``), and
        ``self.ordering`` is rebuilt with the module-level ``ordering``
        function over all registered signatures.
        """
        self.funcs[freeze(signature)] = func
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Resolve on the positional ``args`` and call the match with all arguments."""
        func, _ = self.resolve(args)
        return func(*args, **kwargs)

    def resolve(self, args):
        """Find the first registered implementation matching ``args``.

        Args:
            args: tuple of positional arguments to match.

        Returns:
            ``(func, s)`` where ``s`` is the substitution returned by
            ``unify`` for the winning signature.

        Raises:
            NotImplementedError: if no signature of the right length unifies.
        """
        n = len(args)
        # Freeze once; the frozen form is identical for every candidate.
        frozen_args = freeze(args)
        for signature in self.ordering:
            if len(signature) != n:
                continue
            s = unify(frozen_args, signature)
            # ``unify`` signals failure with ``False`` (an empty substitution
            # is a successful match, so truthiness cannot be used here).
            if s is not False:
                return self.funcs[signature], s
        raise NotImplementedError("No match found. \nKnown matches: "
                                  + str(self.ordering) + "\nInput: " + str(args))

    def register(self, *signature):
        """Decorator form of ``add``.

        The decorator returns the dispatcher itself, so the decorated name is
        rebound to this dispatcher rather than to the original function.
        """
        def _(func):
            self.add(signature, func)
            return self
        return _
38        return _
39
40
class VarDispatcher(Dispatcher):
    """ A dispatcher that calls functions with variable names

    Matching works exactly as in ``Dispatcher``: the positional arguments
    are unified against the registered signatures.  The selected function,
    however, is called with keyword arguments only.  Every entry of the
    resulting substitution is passed as ``<key>.token=<value>``.  Keyword
    arguments given to the call itself are accepted but ignored.

    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher('d')
    >>> x = var('x')
    >>> @d.register('inc', x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register('double', x)
    ... def f(x):
    ...     return x * 2
    >>> d('inc', 10)
    11
    >>> d('double', 10)
    20
    """
    def __call__(self, *args, **kwargs):
        target, bindings = self.resolve(args)
        by_name = {}
        for variable, value in bindings.items():
            by_name[variable.token] = value
        return target(**by_name)
60        return func(**d)
61
62
# Default registry used by ``match``: maps function name -> dispatcher.
global_namespace = dict()  # type: ignore[var-annotated]
64
65
def match(*signature, **kwargs):
    """Decorator registering the decorated function under ``signature``.

    The function is added to a dispatcher keyed by the function's
    ``__name__`` in a namespace dict.  The dispatcher is created on first use.

    Keyword options (other keywords are ignored):
        namespace: dict holding the dispatchers (default ``global_namespace``).
        Dispatcher: class used to create a missing dispatcher
            (default ``Dispatcher``).

    The decorator returns the dispatcher, so several ``@match`` definitions
    sharing one name accumulate into the same dispatcher.
    """
    registry = kwargs.get('namespace', global_namespace)
    factory = kwargs.get('Dispatcher', Dispatcher)

    def register_in_namespace(func):
        key = func.__name__
        if key not in registry:
            registry[key] = factory(key)
        target = registry[key]
        target.add(signature, func)
        return target

    return register_in_namespace
80    return _
81
82
def supercedes(a, b):
    """ ``a`` is a more specific match than ``b``

    Returns a bool:

    * If ``b`` is a logic variable and ``a`` is not, ``a`` wins immediately.
    * If ``a`` and ``b`` do not unify, neither is more specific: ``False``.
    * Otherwise the unifier is reduced to the bindings that are not
      variable-to-variable.  ``a`` supersedes ``b`` when reifying ``a`` with
      that reduced substitution leaves ``a`` unchanged.
    * In every other case (including the previous implicit ``None``
      fall-through) the result is ``False``.
    """
    if isvar(b) and not isvar(a):
        return True
    s = unify(a, b)
    if s is False:
        return False
    # Drop pure renamings (var -> var); they say nothing about specificity.
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    if reify(a, s) == a:
        return True
    # Either ``b`` is the more specific one, or the patterns merely overlap;
    # in both cases ``a`` does not supersede ``b``.
    return False
95
96
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """ A should be checked before B

    True when ``a`` supersedes ``b`` and ``b`` does not supersede ``a``.
    When each supersedes the other, the tie is broken by
    ``tie_breaker(a) > tie_breaker(b)`` (``hash`` by default).
    """
    if not supercedes(a, b):
        return False
    if not supercedes(b, a):
        return True
    return tie_breaker(a) > tie_breaker(b)
107    return False
108
109
# Taken from multipledispatch
def ordering(signatures):
    """ A sane ordering of signatures to check, first to last

    Builds a graph with an edge ``a -> b`` whenever ``edge(a, b)`` holds
    (``a`` should be checked before ``b``).  Every signature appears as a
    node, even one without outgoing edges.  The graph is then topologically
    sorted with ``_toposort``.
    """
    nodes = [tuple(sig) for sig in signatures]
    graph: dict = {}
    for src in nodes:
        for dst in nodes:
            if edge(src, dst):
                graph.setdefault(src, []).append(dst)
    # Signatures with no outgoing edges still need to be nodes.
    for node in nodes:
        graph.setdefault(node, [])
    return _toposort(graph)
122    return _toposort(edges)
123