# mypy: allow-untyped-defs
from .core import unify, reify  # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first  # type: ignore[import]


class Dispatcher:
    """Dispatch calls to functions registered against unification patterns.

    Each registered signature is a tuple pattern (possibly containing logic
    variables). On call, the positional arguments are unified against the
    signatures in ``self.ordering``; the first one that unifies selects the
    implementation to run.
    """

    def __init__(self, name):
        self.name = name
        # frozen signature -> implementation
        self.funcs = {}
        # signatures in the order ``resolve`` tries them (built by ``ordering``)
        self.ordering = []

    def add(self, signature, func):
        """Register ``func`` for ``signature`` and recompute the check order."""
        self.funcs[freeze(signature)] = func
        # The complete ordering is rebuilt on every registration.
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Resolve on the positional ``args``; call the match with all arguments."""
        func, _ = self.resolve(args)
        return func(*args, **kwargs)

    def resolve(self, args):
        """Return ``(func, substitution)`` for the first signature matching ``args``.

        Only signatures with the same length as ``args`` are tried, in
        ``self.ordering`` order.

        Raises:
            NotImplementedError: if no registered signature unifies with ``args``.
        """
        n = len(args)
        # The frozen arguments do not depend on the signature being tried,
        # so freeze them once instead of once per candidate.
        frozen_args = freeze(args)
        for signature in self.ordering:
            if len(signature) != n:
                continue
            s = unify(frozen_args, signature)
            # Failure is signalled by ``False`` specifically; a falsy
            # substitution such as ``{}`` still counts as a match.
            if s is not False:
                return self.funcs[signature], s
        raise NotImplementedError("No match found. \nKnown matches: "
                                  + str(self.ordering) + "\nInput: " + str(args))

    def register(self, *signature):
        """Decorator form of ``add``.

        The decorated name is rebound to this dispatcher, not to the original
        function, so further overloads can be registered under the same name.
        """
        def _(func):
            self.add(signature, func)
            return self
        return _


class VarDispatcher(Dispatcher):
    """ A dispatcher that calls functions with the matched variables as keyword arguments
    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher('d')
    >>> x = var('x')
    >>> @d.register('inc', x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register('double', x)
    ... def f(x):
    ...     return x * 2
    >>> d('inc', 10)
    11
    >>> d('double', 10)
    20
    """
    def __call__(self, *args, **kwargs):
        func, s = self.resolve(args)
        # Each bound variable is passed to ``func`` as a keyword argument named
        # after the variable's ``token``.
        # NOTE(review): ``kwargs`` is accepted but not forwarded here.
        d = {k.token: v for k, v in s.items()}
        return func(**d)


# Default registry used by ``match``: function name -> dispatcher.
global_namespace = {}  # type: ignore[var-annotated]


def match(*signature, **kwargs):
    """Decorator registering the decorated function under ``signature``.

    Functions are grouped by ``__name__``: the first use of a name creates a
    dispatcher in ``namespace``, and later uses add to that same dispatcher.

    Keyword Args:
        namespace: dict mapping names to dispatchers (default ``global_namespace``).
        Dispatcher: dispatcher class instantiated for new names
            (default ``Dispatcher``).

    Returns:
        A decorator that returns the dispatcher, not the original function.
    """
    namespace = kwargs.get('namespace', global_namespace)
    dispatcher = kwargs.get('Dispatcher', Dispatcher)

    def _(func):
        name = func.__name__
        # Only build a dispatcher when the name is new.
        if name not in namespace:
            namespace[name] = dispatcher(name)
        d = namespace[name]
        d.add(signature, func)
        return d
    return _


def supercedes(a, b):
    """ ``a`` is a more specific match than ``b``

    Returns ``True`` or ``False``.
    """
    # A concrete pattern is always more specific than a bare variable.
    if isvar(b) and not isvar(a):
        return True
    s = unify(a, b)
    if s is False:
        return False
    # Drop variable-to-variable bindings (pure renamings).
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    # ``a`` is more specific when the remaining bindings leave it unchanged.
    # Every other outcome means "not more specific". The original code could
    # fall off the end and return ``None`` here; ``bool`` keeps the same
    # truthiness while always returning a real boolean.
    return bool(reify(a, s) == a)


# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """ A should be checked before B
    Tie broken by tie_breaker, defaults to ``hash``

    NOTE: ``hash`` of strings is randomized per process unless PYTHONHASHSEED
    is fixed, so ties between mutually-superseding signatures may be broken
    differently across runs.
    """
    if not supercedes(a, b):
        return False
    if supercedes(b, a):
        # Each is at least as specific as the other: fall back to the tie-breaker.
        return tie_breaker(a) > tie_breaker(b)
    return True


# Taken from multipledispatch
def ordering(signatures):
    """ A sane ordering of signatures to check, first to last
    Topological sort of edges as given by ``edge`` and ``supercedes``
    """
    signatures = list(map(tuple, signatures))
    # Every ordered pair (a, b) where a must be checked before b.
    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
    edges = groupby(first, edges)
    # Signatures with no outgoing edge still need a node in the graph.
    for s in signatures:
        if s not in edges:
            edges[s] = []
    # Convert grouped (a, b) pairs into an adjacency list a -> [b, ...].
    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[attr-defined, assignment]
    return _toposort(edges)