xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/multipledispatch/dispatcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from warnings import warn
3import inspect
4from typing_extensions import deprecated
5from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
6from .utils import expand_tuples
7from .variadic import Variadic, isvariadic
8import itertools as itl
9
# Names exported from this module by ``from ... import *``.
__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
           "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
12
class MDNotImplementedError(NotImplementedError):
    """ A NotImplementedError for multiple dispatch

    When an implementation invoked through ``Dispatcher.__call__`` raises
    this exception, the dispatcher catches it and tries the remaining
    matching implementations in order; a plain ``NotImplementedError`` is
    raised only if none of them completes.
    """
15
16
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing ambiguous signatures

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected; only its
        ``name`` attribute is read here
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.reorder
        warning_text
    """
    # Build the human-readable message first, then hand it to ``warn``.
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
30
31
@deprecated(
    "`halt_ordering` is deprecated, you can safely remove this call.",
    category=FutureWarning,
)
def halt_ordering():
    """Deprecated interface to temporarily disable ordering.

    The function body is empty, so a call does no work of its own; the
    ``deprecated`` decorator above is configured with the ``FutureWarning``
    category and the message telling callers the call can be removed.
    """
38
39
@deprecated(
    "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
    "you should call the `reorder()` method on each dispatcher.",
    category=FutureWarning,
)
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Deprecated interface to temporarily resume ordering.

    The function body is empty and ``on_ambiguity`` is accepted but never
    used. As the deprecation message says, callers who want eager ordering
    should call ``reorder()`` on each dispatcher instead.
    """
47
48
def variadic_signature_matches_iter(types, full_signature):
    """Lazily check input types against a signature ending in a variadic.

    Yields one boolean per element of ``types`` (whether that type is a
    subclass of the signature element it is compared with), followed by one
    final boolean describing whether the signature was fully consumed.

    Notes
    -----
    The walk keeps a "current" signature element, starting with the first:

    For each type in ``types``:
        yield ``issubclass(type, current)``
        If the current element is not variadic
            advance to the next signature element
        (a variadic element is kept, so it can absorb further types)
    After all types are consumed, try to advance once more:
        If nothing is left, the current element must be variadic
            so yield True
        Otherwise required signature elements were left unmatched
            so yield False

    ``full_signature`` must be non-empty and is expected to end with a
    variadic element; the unguarded ``next`` calls rely on that.
    """
    remaining = iter(full_signature)
    current = next(remaining)

    for arg_type in types:
        yield issubclass(arg_type, current)
        if not isvariadic(current):
            # A plain element consumes exactly one argument type, so step
            # forward to the next element of the signature.
            current = next(remaining)

    # All argument types are consumed; see whether signature items remain.
    exhausted = object()
    leftover = next(remaining, exhausted)
    if leftover is exhausted:
        # Nothing left: we stopped on the trailing variadic element, which
        # is allowed to match zero further arguments.
        assert isvariadic(current)
        yield True
    else:
        # We have signature items left over, so all of our arguments
        # haven't matched
        yield False
86
87
def variadic_signature_matches(types, full_signature):
    """Return True if ``types`` matches the variadic ``full_signature``.

    ``full_signature`` must be non-empty (checked with an ``assert``). The
    result is True only when every value produced by
    ``variadic_signature_matches_iter`` is truthy; evaluation stops at the
    first mismatch.
    """
    assert full_signature
    for matched in variadic_signature_matches_iter(types, full_signature):
        if not matched:
            return False
    return True
92
93
class Dispatcher:
    """ Dispatch methods based on type signature
    Use ``dispatch`` to add implementations
    Examples
    --------
    >>> # xdoctest: +SKIP("bad import name")
    >>> from multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1
    >>> @dispatch(float)
    ... def f(x):
    ...     return x - 1
    >>> f(3)
    4
    >>> f(3.0)
    2.0

    Notes
    -----
    Implementations live in ``funcs``, keyed by signature tuple.  The
    resolution order (``_ordering``) is computed lazily and dropped whenever
    a new implementation is added; ``_cache`` memoises the implementation
    chosen for each concrete tuple of argument types.
    """
    # name / __name__ : name of the dispatched function
    # funcs           : dict mapping signature tuple -> implementation
    # _ordering       : cached resolution order (unset until first needed)
    # _cache          : dict mapping argument-type tuple -> implementation
    # doc             : optional extra text included in ``__doc__``
    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        """ Create an empty dispatcher called ``name`` with optional ``doc`` """
        self.name = self.__name__ = name
        self.funcs = {}
        self.doc = doc

        self._cache = {}

    def register(self, *types, **kwargs):
        """ register dispatcher with new implementation
        >>> # xdoctest: +SKIP
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1
        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1
        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]
        >>> f(1)
        2
        >>> f(1.0)
        0.0
        >>> f([1, 2, 3])
        [3, 2, 1]

        Returns a decorator that adds the decorated function under ``types``
        and hands the function back unchanged, so decorators can be stacked.
        """
        def _df(func):
            # Any keyword arguments are forwarded verbatim to ``add``.
            self.add(types, func, **kwargs)   # type: ignore[call-arg]
            return func
        return _df

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` as reported by ``inspect.signature``

        Returns ``None`` implicitly when ``inspect.signature`` is unavailable.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ get annotations of function positional parameters

        Returns a tuple of annotations only when every positional parameter
        is annotated; otherwise returns ``None`` implicitly.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            # Only positional parameters take part in dispatch.
            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if all(ann is not Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func):
        """ Add new types/method pair to dispatcher
        >>> # xdoctest: +SKIP
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)
        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        An empty ``signature`` is replaced by the function's positional
        annotations when all of them are present.  Tuple entries are treated
        as unions and expanded into one registration per combination.  A
        one-element list as the last entry denotes a variadic argument.

        Adding an implementation clears the call cache and discards the
        cached ordering; ambiguities are reported through the
        ``on_ambiguity`` callback of ``reorder`` the next time the ordering
        is computed.  See ``ambiguity_warn`` for an example callback.

        Raises ``TypeError`` for entries that are neither types nor lists,
        for a variadic entry that is not last, or for a variadic list that
        does not contain exactly one element.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func)
            return

        new_signature = []

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
                                f"In signature: <{str_sig}>\n"
                                f"In function: {self.name}")

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError(
                        'Variadic signature must be the last element'
                    )

                if len(typ) != 1:
                    raise TypeError(
                        'Variadic signature must contain exactly one element. '
                        'To use a variadic union type place the desired types '
                        'inside of a tuple, e.g., [(int, str)]'
                    )
                new_signature.append(Variadic[typ[0]])
            else:
                new_signature.append(typ)

        self.funcs[tuple(new_signature)] = func
        # Previously cached resolutions may no longer be the best match.
        self._cache.clear()

        # Drop the cached ordering so it is recomputed lazily on next use.
        try:
            del self._ordering
        except AttributeError:
            pass

    @property
    def ordering(self):
        """ Resolution order of the registered signatures, computed on demand """
        try:
            return self._ordering
        except AttributeError:
            return self.reorder()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute and cache the signature ordering

        If ambiguous signature pairs are found, ``on_ambiguity`` is called
        with this dispatcher and the set of ambiguities.  Returns the new
        ordering.
        """
        self._ordering = od = ordering(self.funcs)
        amb = ambiguities(self.funcs)
        if amb:
            on_ambiguity(self, amb)
        return od

    def __call__(self, *args, **kwargs):
        """ Call the implementation matching the types of ``args``

        Only positional arguments take part in dispatch; keyword arguments
        are passed through.  Raises ``NotImplementedError`` when no
        implementation matches, or when every matching implementation
        raised ``MDNotImplementedError``.
        """
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError as e:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError as e:
            # The chosen implementation declined; fall back to the other
            # matching implementations in resolution order.
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError(
                "Matching functions for "
                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e

    def __str__(self):
        return f"<dispatched {self.name}>"
    __repr__ = __str__

    def dispatch(self, *types):
        """Determine appropriate implementation for this type signature
        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.
        >>> # xdoctest: +SKIP
        >>> from multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1
        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4
        >>> print(inc.dispatch(float))
        None
        See Also:
          ``multipledispatch.conflict`` - module to determine resolution order
        """

        # Fast path: an implementation registered for exactly these types.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation matching ``types``, in resolution order """

        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result
            elif len(signature) and isvariadic(signature[-1]):
                if variadic_signature_matches(types, signature):
                    result = self.funcs[signature]
                    yield result

    @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
    def resolve(self, types):
        """ Determine appropriate implementation for this type signature
        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        return self.dispatch(*types)

    def __getstate__(self):
        """ Return the picklable state: name, implementations and doc

        The ordering and call cache are derived data and are rebuilt in
        ``__setstate__``.
        """
        return {'name': self.name,
                'funcs': self.funcs,
                # ``doc`` may be unset on instances restored from state that
                # did not carry it, hence the guarded read.
                'doc': getattr(self, 'doc', None)}

    def __setstate__(self, d):
        """ Restore a dispatcher from the dict produced by ``__getstate__``

        ``__init__`` is not run on unpickling, so every slot that other
        methods read must be populated here.
        """
        # Restore both aliases of the name, as ``__init__`` does.
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # State written without a 'doc' entry is still accepted.
        self.doc = d.get('doc')
        self._ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        """ Docstring assembled from ``doc`` and the implementations' docstrings """
        docs = [f"Multiply dispatched method: {self.name}"]

        if self.doc:
            docs.append(self.doc)

        # Signatures whose implementation has no docstring are listed
        # together at the end.
        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = f'Inputs: <{str_signature(sig)}>\n'
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        """ Return the docstring of the implementation selected for ``args`` """
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        """ Return the source text of the implementation selected for ``args``

        Raises ``TypeError`` when no implementation matches.
        """
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
375
376
def source(func):
    """ Return the source file path and source code of ``func`` as one string

    The result is a ``File: <path>`` header line, a blank line, and then the
    text returned by ``inspect.getsource``.
    """
    header = f'File: {inspect.getsourcefile(func)}\n\n'
    return header + inspect.getsource(func)
381
382
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Acts as a descriptor: attribute access records the instance and owner
    class, and calls pass the recorded instance as the first argument to the
    selected implementation.

    See Also:
        Dispatcher
    """
    __slots__ = ('obj', 'cls')

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` with the first one skipped """
        if not hasattr(inspect, "signature"):
            return None
        parameters = inspect.signature(func).parameters
        # Skip the first parameter so that it does not take part in dispatch.
        return itl.islice(parameters.values(), 1, None)

    def __get__(self, instance, owner):
        """ Remember the accessing instance and owner class; return ``self`` """
        self.obj, self.cls = instance, owner
        return self

    def __call__(self, *args, **kwargs):
        """ Dispatch on the types of ``args`` and call with the bound instance

        Raises ``NotImplementedError`` when no implementation matches.
        """
        types = tuple(type(arg) for arg in args)
        func = self.dispatch(*types)
        if func:
            return func(self.obj, *args, **kwargs)
        raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
407
408
def str_signature(sig):
    """ Render a type signature as comma-separated type names
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
415
416
def warning_text(name, amb):
    """ Build the message used for ambiguity warnings

    ``name`` is the dispatched function's name and ``amb`` the collection of
    ambiguous signature pairs.  The message lists each ambiguous pair and
    then suggests one ``@dispatch`` definition per pair, using the pair's
    ``super_signature``.
    """
    pieces = [
        f"\nAmbiguities exist in dispatched function {name}\n\n",
        "The following signatures may result in ambiguous behavior:\n",
    ]
    # One tab-indented line per ambiguous pair, each signature in brackets.
    for pair in amb:
        rendered = ', '.join('[' + str_signature(s) + ']' for s in pair)
        pieces.append("\t" + rendered + "\n")
    pieces.append("\n\nConsider making the following additions:\n\n")
    suggestions = [
        '@dispatch(' + str_signature(super_signature(s)) + f')\ndef {name}(...)'
        for s in amb
    ]
    pieces.append('\n\n'.join(suggestions))
    return ''.join(pieces)
428