xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/multipledispatch/conflict.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from .utils import _toposort, groupby
3from .variadic import isvariadic
4import operator
5
# Public API of this module.
__all__ = [
    "AmbiguityWarning",
    "supercedes",
    "consistent",
    "ambiguous",
    "ambiguities",
    "super_signature",
    "edge",
    "ordering",
]
8
class AmbiguityWarning(Warning):
    """Warning category for signature ambiguities such as those reported by ``ambiguities``."""
11
12
def supercedes(a, b):
    """Return whether signature ``a`` supercedes signature ``b``.

    ``a`` supercedes ``b`` when it is consistent with ``b`` and at least as
    specific at every position:

    * equal lengths: each type in ``a`` is a subclass of the type at the same
      position in ``b`` (so every signature supercedes itself);
    * ``a`` shorter: only an empty ``a`` against a single variadic ``b``;
    * ``a`` longer: ``b`` must end in a variadic type, and the surplus
      entries of ``a`` must pass ``issubclass`` against it.

    Variadic types (as reported by ``isvariadic``) may only be the last
    element of a signature; this is checked with ``assert``.
    """
    len_a = len(a)
    len_b = len(b)

    if len_a < len_b:
        # An empty ``a`` can only supercede a lone variadic ``b``.
        return not a and len_b == 1 and isvariadic(b[-1])

    if len_a == len_b:
        return all(issubclass(x, y) for x, y in zip(a, b))

    # ``a`` is longer: walk both signatures in step. A variadic at the end of
    # ``b`` stays in place while it absorbs the remaining entries of ``a``.
    i = j = 0
    while i < len_a and j < len_b:
        x, y = a[i], b[j]
        if isvariadic(x):
            # Variadics are last; both sides must end on their final element.
            assert i == len_a - 1
            return j == len_b - 1 and issubclass(x, y)
        y_is_variadic = isvariadic(y)
        if y_is_variadic:
            assert j == len_b - 1
        if not issubclass(x, y):
            return False
        i += 1
        if not y_is_variadic:
            j += 1
    # ``b`` must have parked on its last (variadic) entry with ``a`` consumed.
    return j == len_b - 1 and i == len_a
41
42
def consistent(a, b):
    """Return whether some argument list could satisfy both ``a`` and ``b``.

    Positions are compatible when their types are related by subclassing in
    either direction. A variadic entry (as reported by ``isvariadic``) is
    expected only as the last element and may stand in for several entries of
    the other signature.
    """
    # An empty signature is consistent with another empty one, or with one
    # whose first entry is variadic.
    if not a:
        return not b or isvariadic(b[0])
    if not b:
        # ``a`` is non-empty here, so only a leading variadic can match.
        return isvariadic(a[0])

    if len(a) == len(b):
        return all(issubclass(x, y) or issubclass(y, x) for x, y in zip(a, b))

    # Different lengths: walk both, letting a variadic tail stay in place while
    # the other side advances.
    i = j = 0
    while i < len(a) and j < len(b):
        x, y = a[i], b[j]
        if not (issubclass(y, x) or issubclass(x, y)):
            return False
        if isvariadic(x):
            j += 1
        elif isvariadic(y):
            i += 1
        else:
            i += 1
            j += 1
    # Variadic types are guaranteed to be the last element, so it only remains
    # to check that the side ending in a variadic consumed the other side.
    # ``x``/``y`` are always bound: both signatures are non-empty here.
    return (isvariadic(x) and j == len(b) or  # type: ignore[possibly-undefined]
            isvariadic(y) and i == len(a))  # type: ignore[possibly-undefined]
75
76
def ambiguous(a, b):
    """Return whether ``a`` and ``b`` are consistent but neither supercedes the other."""
    compatible = consistent(a, b)
    if not compatible:
        return compatible
    return not supercedes(a, b) and not supercedes(b, a)
80
81
def ambiguities(signatures):
    """Return all signature pairs that are ambiguous and not otherwise resolved.

    A pair ``(a, b)`` is reported when ``ambiguous(a, b)`` holds and no
    signature ``c`` in ``signatures`` supercedes both ``a`` and ``b``.
    Each unordered pair is reported once. ``a`` is whichever of the two
    appears first in ``signatures``.

    :param signatures: iterable of signatures, each an iterable of types.
        Duplicate signatures are ignored.
    :returns: a ``set`` of ``(a, b)`` pairs of signature tuples.
    """
    # Normalize to tuples and drop duplicates, keeping first-seen order.
    sigs = list(dict.fromkeys(map(tuple, signatures)))
    result = set()
    for i, a in enumerate(sigs):
        # Pair each signature only with later ones. Filtering on
        # ``hash(a) < hash(b)`` instead would silently skip distinct
        # signatures whose hashes happen to collide.
        for b in sigs[i + 1:]:
            if not ambiguous(a, b):
                continue
            # A signature more specific than both already resolves the clash.
            if any(supercedes(c, a) and supercedes(c, b) for c in sigs):
                continue
            result.add((a, b))
    return result
90
91
def super_signature(signatures):
    """Build a signature that would break ambiguities among ``signatures``.

    At each position, picks the type whose MRO is longest among the types
    the signatures have there. On a tie, the first such signature wins.
    All signatures must share one length, which is checked with ``assert``.
    """
    width = len(signatures[0])
    assert all(len(sig) == width for sig in signatures)

    result = []
    for pos in range(width):
        candidates = [type.mro(sig[pos]) for sig in signatures]
        # ``mro()[0]`` is the class itself.
        result.append(max(candidates, key=len)[0])
    return result
99
100
def edge(a, b, tie_breaker=hash):
    """Return whether signature ``a`` should be checked before ``b``.

    ``a`` goes first when it supercedes ``b``. If each supercedes the other
    (for example, identical signatures), ``a`` goes first only when
    ``tie_breaker(a) > tie_breaker(b)``. ``tie_breaker`` defaults to ``hash``.
    """
    forward = supercedes(a, b)
    if not forward:
        return forward
    if supercedes(b, a):
        return tie_breaker(a) > tie_breaker(b)
    return True
108
109
def ordering(signatures):
    """Return ``signatures`` (as tuples) in a sane order to check, first to last.

    The result is a topological sort (via ``_toposort``) of the graph that has
    an arc ``a -> b`` whenever ``edge(a, b)`` holds, i.e. whenever ``a``
    should be checked before ``b``.
    """
    sigs = [tuple(sig) for sig in signatures]
    arcs = [(src, dst) for src in sigs for dst in sigs if edge(src, dst)]
    # Group arcs by their source signature.
    by_source = groupby(operator.itemgetter(0), arcs)
    # Signatures without outgoing arcs still need a node in the graph.
    for sig in sigs:
        if sig not in by_source:
            by_source[sig] = []
    graph = {src: [dst for _, dst in group] for src, group in by_source.items()}
    return _toposort(graph)
122