xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
# Names exported by a star-import of this module; ``_toposort`` is not listed.
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
def hashable(x):
    """Return True if ``hash(x)`` succeeds, False if it raises TypeError.

    Any other exception raised while hashing is propagated to the caller.
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True
9
10
def transitive_get(key, d):
    """Follow ``key`` through ``d`` until it stops mapping to anything.

    Starting from ``key``, repeatedly replace the current value with
    ``d[value]`` for as long as the value is hashable and present in ``d``;
    the first value that is not is returned.  No cycle detection is
    performed, so a cyclic mapping such as ``{1: 1}`` never terminates.

    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
    >>> transitive_get(1, d)
    4
    """
    current = key
    while True:
        # An unhashable value cannot be looked up in ``d``; the chain ends.
        if not hashable(current) or current not in d:
            return current
        current = d[current]
22
23
def raises(err, lamda):
    """Report whether calling ``lamda()`` raises ``err``.

    Returns True when the call raises an exception matching ``err`` and
    False when it completes normally.  Exceptions that do not match ``err``
    are propagated.
    """
    try:
        lamda()
    except err:
        return True
    return False
30
31
32# Taken from theano/theano/gof/sched.py
33# Avoids licensing issues because this was written by Matthew Rocklin
34def _toposort(edges):
35    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
36    inputs:
37        edges - a dict of the form {a: {b, c}} where b and c depend on a
38    outputs:
39        L - an ordered list of nodes that satisfy the dependencies of edges
40    >>> # xdoctest: +SKIP
41    >>> _toposort({1: (2, 3), 2: (3, )})
42    [1, 2, 3]
43    Closely follows the wikipedia page [2]
44    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
45    Communications of the ACM
46    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
47    """
48    incoming_edges = reverse_dict(edges)
49    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
50    S = ({v for v in edges if v not in incoming_edges})
51    L = []
52
53    while S:
54        n = S.pop()
55        L.append(n)
56        for m in edges.get(n, ()):
57            assert n in incoming_edges[m]
58            incoming_edges[m].remove(n)
59            if not incoming_edges[m]:
60                S.add(m)
61    if any(incoming_edges.get(v, None) for v in edges):
62        raise ValueError("Input has cycles")
63    return L
64
65
def reverse_dict(d):
    """Invert a dependence dict so each value points back at its keys.

    For every ``key`` in ``d`` and every ``val`` in ``d[key]``, ``key`` is
    appended to the tuple stored under ``val`` in the result.  Keys whose
    collection is empty contribute nothing.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: The order of the result's keys, and of the entries inside each
        tuple, follows the iteration order of ``d`` and of its values, so
        it is only as deterministic as that iteration order.
    """
    collected = {}  # type: ignore[var-annotated]
    for source in d:
        for target in d[source]:
            collected.setdefault(target, []).append(source)
    return {target: tuple(sources) for target, sources in collected.items()}
81
82
def xfail(func):
    """Call ``func`` and require that it fails.

    Returns None when ``func()`` raises any ``Exception``.  If ``func()``
    completes without raising, an ``Exception("XFailed test passed")`` is
    raised to flag the unexpected success.
    """
    try:
        func()
    except Exception:
        # Expected failure: nothing to report.
        return
    # Raised outside the try block so the handler above cannot swallow it.
    raise Exception("XFailed test passed")  # pragma:nocover  # noqa: TRY002
89
90
def freeze(d):
    """Recursively convert a container into a hashable equivalent.

    Dicts become a frozenset of frozen ``(key, value)`` pairs, sets become
    frozensets, and tuples and lists become tuples; the conversion is
    applied recursively to the contents.  Any other object is returned
    unchanged.

    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2}) # doctest: +SKIP
    frozenset([(1, 2)])
    """
    if isinstance(d, dict):
        frozen = frozenset(freeze(pair) for pair in d.items())
    elif isinstance(d, set):
        frozen = frozenset(freeze(member) for member in d)
    elif isinstance(d, (tuple, list)):
        frozen = tuple(freeze(element) for element in d)
    else:
        frozen = d
    return frozen
107