xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/multipledispatch/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from collections import OrderedDict
3
# Public helpers re-exported by this module (``_toposort`` is internal).
__all__ = [
    "raises",
    "expand_tuples",
    "reverse_dict",
    "groupby",
    "typename",
]
5
def raises(err, lamda):
    """Report whether calling ``lamda()`` raises ``err``.

    ``err`` may be an exception class or a tuple of classes (anything valid in
    an ``except`` clause). Exceptions not matching ``err`` propagate.

    Returns True if ``err`` was raised, False if the call completed normally.
    """
    try:
        lamda()
    except err:
        return True
    else:
        return False
12
13
def expand_tuples(L):
    """Expand every tuple in ``L`` into alternatives and return all combinations.

    Non-tuple entries are fixed. Each tuple entry contributes one of its items
    to every output combination. Within the result, earlier positions vary
    fastest. An empty tuple entry yields no combinations at all.

    >>> expand_tuples([1, (2, 3)])
    [(1, 2), (1, 3)]
    >>> expand_tuples([1, 2])
    [(1, 2)]
    """
    # Fold from the last position back to the first. Each pass prepends the
    # current position's choices onto every partial combination built so far.
    combos = [()]
    for entry in reversed(L):
        choices = entry if isinstance(entry, tuple) else (entry,)
        combos = [(choice,) + tail for tail in combos for choice in choices]
    return combos
29
30
31# Taken from theano/theano/gof/sched.py
32# Avoids licensing issues because this was written by Matthew Rocklin
33def _toposort(edges):
34    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
35    inputs:
36        edges - a dict of the form {a: {b, c}} where b and c depend on a
37    outputs:
38        L - an ordered list of nodes that satisfy the dependencies of edges
39    >>> _toposort({1: (2, 3), 2: (3, )})
40    [1, 2, 3]
41    >>> # Closely follows the wikipedia page [2]
42    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
43    >>> # Communications of the ACM
44    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
45    """
46    incoming_edges = reverse_dict(edges)
47    incoming_edges = OrderedDict((k, set(val))
48                                 for k, val in incoming_edges.items())
49    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
50    L = []
51
52    while S:
53        n, _ = S.popitem()
54        L.append(n)
55        for m in edges.get(n, ()):
56            assert n in incoming_edges[m]
57            incoming_edges[m].remove(n)
58            if not incoming_edges[m]:
59                S[m] = None
60    if any(incoming_edges.get(v, None) for v in edges):
61        raise ValueError("Input has cycles")
62    return L
63
64
def reverse_dict(d):
    """Reverses direction of dependence dict

    Returns an ``OrderedDict`` mapping every value found in ``d``'s
    collections to a tuple of the keys whose collections contain it.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: The output order is derived from the input's iteration order. Result
        keys appear in the order their values are first encountered, and each
        tuple lists keys in the order they are iterated in ``d``. Keys with an
        empty collection (``'c'`` above) do not appear in the result.
    """
    # Collect into lists (amortized O(1) append) and convert to tuples once.
    # Concatenating a new tuple per key was quadratic in the keys per value.
    grouped = OrderedDict()  # type: ignore[var-annotated]
    for key in d:
        for val in d[key]:
            grouped.setdefault(val, []).append(key)
    return OrderedDict((val, tuple(keys)) for val, keys in grouped.items())
80
81
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
    """Group the items of ``seq`` by the result of ``func``.

    Returns an ``OrderedDict`` whose keys are the distinct ``func(item)``
    values, in order of first appearance. Each key maps to a list of the
    items that produced it, in their original order. ``func`` is called
    once per item.

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
    """
    groups = OrderedDict()  # type: ignore[var-annotated]
    for element in seq:
        groups.setdefault(func(element), []).append(element)
    return groups
103
104
def typename(type):
    """Get the name of `type`.

    Parameters
    ----------
    type : Union[Type, Tuple[Type]]
        A class (anything with ``__name__``) or a sequence of them. Sequences
        may be nested.

    Returns
    -------
    str
        ``type.__name__`` when it exists. Otherwise a one-element sequence
        yields its element's name, and any other sequence yields the element
        names joined as ``"(A, B, ...)"``.

    Examples
    --------
    >>> typename(int)
    'int'
    >>> typename((int, float))
    '(int, float)'
    """
    try:
        return type.__name__
    except AttributeError:
        pass
    # No __name__: treat `type` as a sequence of types.
    if len(type) == 1:
        (sole,) = type
        return typename(sole)
    names = ", ".join(typename(member) for member in type)
    return f"({names})"
127