# mypy: allow-untyped-defs
import inspect
import sys

from .dispatcher import Dispatcher, MethodDispatcher

global_namespace = {}  # type: ignore[var-annotated]

__all__ = ["dispatch", "ismethod"]


def dispatch(*types, **kwargs):
    """Dispatch function on the types of the inputs.

    Supports dispatch on all non-keyword arguments.
    Collects implementations based on the function name and ignores the
    enclosing module: every function decorated under the same name shares one
    ``Dispatcher``, stored by default in the module-level ``global_namespace``
    dict. Pass your own dict via the ``namespace`` keyword argument to isolate
    a group of implementations. Keyword arguments other than ``namespace`` are
    ignored.

    If ambiguous type signatures occur a warning is raised when the function is
    defined suggesting the additional method to break the ambiguity.

    Functions that have a parameter named ``self`` (see ``ismethod``) are
    treated as methods: their dispatcher is looked up in the local namespace of
    the class body currently being executed (a new ``MethodDispatcher`` is
    created if none exists there yet), and ``namespace`` is not consulted.

    Returns:
        A decorator. Applying it registers the decorated function under
        ``types`` and returns the shared dispatcher object in place of the
        function, so the dispatcher is what ends up bound to the name.

    Example:
        >>> # xdoctest: +SKIP
        >>> @dispatch(int)
        ... def f(x):
        ...     return x + 1
        >>> @dispatch(float)
        ... def f(x):
        ...     return x - 1
        >>> # xdoctest: +SKIP
        >>> f(3)
        4
        >>> f(3.0)
        2.0
        >>> # Specify an isolated namespace with the namespace keyword argument
        >>> my_namespace = {}
        >>> @dispatch(int, namespace=my_namespace)
        ... def foo(x):
        ...     return x + 1
        >>> # Dispatch on instance methods within classes
        >>> class MyClass(object):
        ...     @dispatch(list)
        ...     def __init__(self, data):
        ...         self.data = data
        ...     @dispatch(int)
        ...     def __init__(self, datum):
        ...         self.data = [datum]
        >>> MyClass([1, 2, 3]).data
        [1, 2, 3]
        >>> MyClass(3).data
        [3]
    """
    namespace = kwargs.get("namespace", global_namespace)

    def _df(func):
        name = func.__name__

        if ismethod(func):
            # A method is decorated while its class body is still executing, so
            # any earlier overload with this name lives in the class body's
            # local namespace, which is the frame that invoked this decorator.
            # The frame object itself is not kept in a local, which avoids a
            # frame -> locals -> frame reference cycle.
            class_locals = inspect.currentframe().f_back.f_locals  # type: ignore[union-attr]
            # Only build a new MethodDispatcher when none exists yet, instead
            # of allocating a throwaway one on every registration.
            if name in class_locals:
                dispatcher = class_locals[name]
            else:
                dispatcher = MethodDispatcher(name)
        else:
            if name not in namespace:
                namespace[name] = Dispatcher(name)
            dispatcher = namespace[name]

        dispatcher.add(types, func)
        return dispatcher

    return _df


def ismethod(func):
    """Is func a method?

    Note that this has to work as the method is defined but before the class is
    defined. At this stage methods look like functions, so the check is purely
    based on the signature: ``func`` counts as a method when any of its
    parameters is named ``self``.
    """
    return "self" in inspect.signature(func).parameters