# mypy: allow-untyped-defs
import inspect
import itertools as itl
from warnings import warn

from typing_extensions import deprecated

from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
from .variadic import Variadic, isvariadic

__all__ = [
    "MDNotImplementedError",
    "ambiguity_warn",
    "halt_ordering",
    "restart_ordering",
    "variadic_signature_matches_iter",
    "variadic_signature_matches",
    "Dispatcher",
    "source",
    "MethodDispatcher",
    "str_signature",
    "warning_text",
]


class MDNotImplementedError(NotImplementedError):
    """A NotImplementedError for multiple dispatch.

    Raising this from an implementation makes ``Dispatcher.__call__`` fall
    back to the next matching implementation instead of failing outright.
    """


def ambiguity_warn(dispatcher, ambiguities):
    """Raise warning when ambiguity is detected.

    This is the default ``on_ambiguity`` callback of ``Dispatcher.reorder``.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.reorder
        warning_text
    """
    warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)


@deprecated(
    "`halt_ordering` is deprecated, you can safely remove this call.",
    category=FutureWarning,
)
def halt_ordering():
    """Deprecated interface to temporarily disable ordering (now a no-op)."""


@deprecated(
    "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
    "you should call the `reorder()` method on each dispatcher.",
    category=FutureWarning,
)
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Deprecated interface to temporarily resume ordering (now a no-op).

    ``on_ambiguity`` is accepted for backward compatibility and ignored.
    """


def variadic_signature_matches_iter(types, full_signature):
    """Check if a set of input types matches a variadic signature.

    Yields one boolean per input type (whether that type is a subclass of the
    signature element it is paired with), followed by a final boolean saying
    whether the signature was consumed correctly. ``all()`` over the
    generator is therefore the overall match result.

    Notes
    -----
    The algorithm is as follows:
    Initialize the current signature to the first in the sequence
    For each type in `types`:
        If no signature elements are left
            yield False and stop: there are more types than slots
        yield True if the type matches the current signature
        If the current signature is not variadic
            Advance to the next signature (a variadic element is reused
            for every remaining type)
    Finally:
        yield True if every element was paired one-to-one, or if only the
        trailing variadic element remains; otherwise yield False
    """
    exhausted = object()  # sentinel: no signature elements left
    sigiter = iter(full_signature)
    sig = next(sigiter, exhausted)
    for typ in types:
        if sig is exhausted:
            # More arguments than signature elements can absorb. Using a
            # sentinel avoids StopIteration leaking out of the generator
            # (which PEP 479 turns into RuntimeError).
            yield False
            return
        yield issubclass(typ, sig)
        if not isvariadic(sig):
            # We're not matching a variadic argument, so move to the next
            # element in the signature.
            sig = next(sigiter, exhausted)

    if sig is exhausted:
        # Every signature element was paired with exactly one argument.
        yield True
    elif next(sigiter, exhausted) is exhausted:
        # Only the current element is left unpaired; that is acceptable
        # (zero further arguments) only if it is variadic.
        yield isvariadic(sig)
    else:
        # We have signature items left over, so all of our arguments
        # haven't matched.
        yield False


def variadic_signature_matches(types, full_signature):
    """Return True if ``types`` matches the variadic ``full_signature``.

    ``Dispatcher.dispatch_iter`` only calls this for non-empty signatures
    whose last element is variadic.
    """
    # An empty signature is a caller error. Note that an empty ``types`` can
    # still match, e.g. a signature consisting of a lone variadic element.
    assert full_signature
    return all(variadic_signature_matches_iter(types, full_signature))


class Dispatcher:
    """ Dispatch methods based on type signature
    Use ``dispatch`` to add implementations
    Examples
    --------
    >>> # xdoctest: +SKIP("bad import name")
    >>> from multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1
    >>> @dispatch(float)
    ... def f(x):
    ...     return x - 1
    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    # name / __name__: dispatcher name; funcs: signature tuple -> implementation;
    # _ordering: lazily computed resolution order (see ``ordering``);
    # _cache: concrete argument-type tuple -> resolved implementation;
    # doc: optional extra text included in ``__doc__``.
    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        self.funcs = {}
        self.doc = doc

        self._cache = {}

    def register(self, *types, **kwargs):
        """ register dispatcher with new implementation
        >>> # xdoctest: +SKIP
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1
        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1
        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]
        >>> f(1)
        2
        >>> f(1.0)
        0.0
        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _df(func):
            # NOTE(review): ``add`` takes no extra keyword arguments, so any
            # ``kwargs`` given here raise TypeError when the decorator runs.
            self.add(types, func, **kwargs)  # type: ignore[call-arg]
            return func
        return _df

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func``, or ``None`` if ``inspect.signature`` is unavailable."""
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ get annotations of function positional parameters

        Returns a tuple with one annotation per positional parameter (as
        selected by ``get_func_params``), or ``None`` if ``get_func_params``
        returned nothing truthy or any such parameter is unannotated.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if all(ann is not Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func):
        """ Add new types/method pair to dispatcher

        ``signature`` is a sequence of types. A tuple element is a union and
        is expanded into one registration per alternative; a one-element
        list as the final element marks a variadic argument. If
        ``signature`` is empty, the positional-parameter annotations of
        ``func`` are used instead (when all of them are annotated).

        >>> # xdoctest: +SKIP
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)
        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>
        >>> # ``add`` itself does not check for ambiguities; it only
        >>> # invalidates the cached ordering. When the ordering is next
        >>> # computed, ``reorder`` calls its ``on_ambiguity`` callback with
        >>> # the dispatcher and a set of ambiguous type signature pairs.
        >>> # See ``ambiguity_warn`` for the default callback.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func)
            return

        new_signature = []

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
                                f"In signature: <{str_sig}>\n"
                                f"In function: {self.name}")

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError(
                        'Variadic signature must be the last element'
                    )

                if len(typ) != 1:
                    raise TypeError(
                        'Variadic signature must contain exactly one element. '
                        'To use a variadic union type place the desired types '
                        'inside of a tuple, e.g., [(int, str)]'
                    )
                new_signature.append(Variadic[typ[0]])
            else:
                new_signature.append(typ)

        self.funcs[tuple(new_signature)] = func
        # Any cached resolution may now be stale.
        self._cache.clear()

        # Drop the cached ordering; ``ordering`` recomputes it on next use.
        try:
            del self._ordering
        except AttributeError:
            pass

    @property
    def ordering(self):
        """Resolution order of the registered signatures, computed lazily via ``reorder``."""
        try:
            return self._ordering
        except AttributeError:
            return self.reorder()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """Recompute and cache the resolution order.

        If ``ambiguities`` reports any for the registered signatures,
        ``on_ambiguity(self, amb)`` is called. Returns the new ordering.
        """
        self._ordering = od = ordering(self.funcs)
        amb = ambiguities(self.funcs)
        if amb:
            on_ambiguity(self, amb)
        return od

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError as e:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError as e:
            # Fall back to the remaining matching implementations in order.
            funcs = self.dispatch_iter(*types)
            # Skip the implementation that just raised; this assumes
            # ``dispatch`` returned the first item ``dispatch_iter`` yields.
            next(funcs)  # burn first
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError(
                "Matching functions for "
                f"{self.name}: <{str_signature(types)}> found, but none completed successfully") from e

    def __str__(self):
        return f"<dispatched {self.name}>"
    __repr__ = __str__

    def dispatch(self, *types):
        """Determine appropriate implementation for this type signature
        This method is internal. Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.
        >>> # xdoctest: +SKIP
        >>> from multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1
        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4
        >>> print(inc.dispatch(float))
        None
        See Also:
            ``multipledispatch.conflict`` - module to determine resolution order
        """

        # Fast path: an exact registered signature.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """Yield every implementation whose signature matches ``types``, in ``self.ordering`` order."""

        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result
            elif len(signature) and isvariadic(signature[-1]):
                if variadic_signature_matches(types, signature):
                    result = self.funcs[signature]
                    yield result

    @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
    def resolve(self, types):
        """ Determine appropriate implementation for this type signature
        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        return self.dispatch(*types)

    def __getstate__(self):
        # ``doc`` is included so the ``__doc__`` property keeps working after
        # unpickling (it reads ``self.doc``).
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        # Restore every slot that other methods read: ``__name__`` and
        # ``doc`` would otherwise be unset and raise AttributeError.
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # ``.get`` lets state dicts without a 'doc' entry still load; None
        # matches the ``__init__`` default.
        self.doc = d.get('doc')
        self._ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        """Build documentation from ``self.doc`` and each implementation's docstring."""
        docs = [f"Multiply dispatched method: {self.name}"]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = f'Inputs: <{str_signature(sig)}>\n'
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        """Return the docstring of the implementation selected for ``args``.

        Raises TypeError if no implementation matches, consistent with
        ``_source``.
        """
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return func.__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        """Return the source of the implementation selected for ``args``."""
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))


def source(func):
    """Return ``func``'s source file path header followed by its source code."""
    s = f'File: {inspect.getsourcefile(func)}\n\n'
    s = s + inspect.getsource(func)
    return s


class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature
    See Also:
        Dispatcher
    """
    # obj / cls: the instance and owner recorded by the latest ``__get__``.
    __slots__ = ('obj', 'cls')

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func`` excluding the first one (the instance)."""
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return itl.islice(sig.parameters.values(), 1, None)

    def __get__(self, instance, owner):
        # NOTE(review): the instance is stored on the dispatcher itself, so
        # ``__call__`` uses whichever object was bound by the latest
        # attribute access.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        # Unlike ``Dispatcher.__call__`` this neither caches the resolution
        # nor falls back on MDNotImplementedError.
        types = tuple([type(arg) for arg in args])
        func = self.dispatch(*types)
        if not func:
            raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
        return func(self.obj, *args, **kwargs)


def str_signature(sig):
    """ String representation of type signature
    >>> str_signature((int, float))
    'int, float'
    """
    return ', '.join(cls.__name__ for cls in sig)


def warning_text(name, amb):
    """ The text for ambiguity warnings

    Lists each ambiguous signature pair in ``amb`` and suggests a
    ``@dispatch`` registration for the ``super_signature`` of each pair.
    """
    text = f"\nAmbiguities exist in dispatched function {name}\n\n"
    text += "The following signatures may result in ambiguous behavior:\n"
    for pair in amb:
        text += "\t" + \
                ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
    text += "\n\nConsider making the following additions:\n\n"
    text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
                         + f')\ndef {name}(...)' for s in amb])
    return text