# mypy: allow-untyped-defs
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]


def hashable(x):
    """Return True if ``x`` can be hashed, False otherwise.

    >>> hashable(1)
    True
    >>> hashable([1])
    False
    """
    try:
        hash(x)
    except TypeError:
        return False
    return True


def transitive_get(key, d):
    """Transitive dict.get

    Repeatedly replaces ``key`` with ``d[key]`` until the current key is
    unhashable or not present in ``d``, then returns it. A key that is not
    in ``d`` (or is unhashable) is returned unchanged.

    Note: this loops forever if ``d`` contains a cycle reachable from
    ``key`` (e.g. ``{1: 2, 2: 1}``).

    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
    >>> transitive_get(1, d)
    4
    """
    # The hashable() guard matters: ``key in d`` raises TypeError for
    # unhashable keys (e.g. lists); those are simply returned as-is.
    while hashable(key) and key in d:
        key = d[key]
    return key


def raises(err, lamda):
    """Return True if calling ``lamda()`` raises ``err``, False if it returns.

    Exceptions that do not match ``err`` propagate to the caller.
    """
    try:
        lamda()
    except err:
        return True
    return False


# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
    """Topological sort algorithm by Kahn [1] - O(nodes + edges)

    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges

    Raises ValueError if ``edges`` contains a cycle. Repeated targets under
    one key (e.g. ``{a: (b, b)}``) are treated as a single edge.

    >>> # xdoctest: +SKIP
    >>> _toposort({1: (2, 3), 2: (3, )})
    [1, 2, 3]

    Closely follows the wikipedia page [2]
    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
    Communications of the ACM
    [2] https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
    """
    # node -> set of predecessors whose edge into it has not been processed yet
    incoming_edges = {k: set(val) for k, val in reverse_dict(edges).items()}
    # Start from keys of ``edges`` that have no incoming edges at all.
    S = {v for v in edges if v not in incoming_edges}
    L = []

    while S:
        n = S.pop()
        L.append(n)
        # dict.fromkeys de-duplicates while keeping order; without it a
        # repeated target would trip the assertion on its second visit,
        # because incoming_edges stores predecessors as a set.
        for m in dict.fromkeys(edges.get(n, ())):
            assert n in incoming_edges[m]
            incoming_edges[m].remove(n)
            if not incoming_edges[m]:
                S.add(m)
    # Any node still waiting on a predecessor is on, or downstream of, a cycle.
    if any(incoming_edges.values()):
        raise ValueError("Input has cycles")
    return L


def reverse_dict(d):
    """Reverses direction of dependence dict

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: The result follows the iteration order of ``d`` and of its
        values. Output keys appear in the order they are first encountered,
        and each tuple lists the keys of ``d`` in the order they were
        visited. Dicts preserve insertion order, so this is deterministic
        for dict inputs. If the values are sets, their iteration order, and
        hence the output order, is not guaranteed.
    """
    # Collect into lists (amortized O(1) append) instead of re-building a
    # tuple on every hit, which is quadratic in the number of dependents.
    grouped: dict = {}
    for key in d:
        for val in d[key]:
            grouped.setdefault(val, []).append(key)
    return {val: tuple(keys) for val, keys in grouped.items()}


def xfail(func):
    """Call ``func`` and require that it fails (expected-failure helper).

    Returns None if ``func()`` raises any ``Exception``. Raises
    ``Exception("XFailed test passed")`` if ``func()`` returns normally.
    """
    try:
        func()
    except Exception:
        # The expected failure happened.
        return
    # Raised outside the ``try`` so the ``except Exception`` above cannot
    # swallow it.
    raise Exception("XFailed test passed")  # pragma:nocover # noqa: TRY002


def freeze(d):
    """Freeze container to hashable form

    Recursively converts dicts and sets to frozensets (a dict becomes a
    frozenset of frozen ``(key, value)`` tuples) and lists/tuples to
    tuples. Any other object is returned unchanged.

    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2})
    frozenset({(1, 2)})
    """
    if isinstance(d, dict):
        return frozenset(map(freeze, d.items()))
    if isinstance(d, set):
        return frozenset(map(freeze, d))
    if isinstance(d, (tuple, list)):
        return tuple(map(freeze, d))
    return d