xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/multipledispatch/core.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import inspect
3import sys
4
5from .dispatcher import Dispatcher, MethodDispatcher
6
# Default registry shared by every ``@dispatch`` call that does not pass an
# explicit ``namespace=`` keyword: maps a function name to its Dispatcher.
global_namespace = dict()  # type: ignore[var-annotated]

__all__ = ["dispatch", "ismethod"]
10
def dispatch(*types, **kwargs):
    """Dispatch function on the types of the inputs.

    Supports dispatch on all non-keyword arguments.
    Collects implementations based on the function name.  Ignores namespaces
    (module boundaries): by default every decorated function shares the
    module-level ``global_namespace`` registry.
    If ambiguous type signatures occur, the dispatcher's ``add`` is expected to
    warn when the function is defined, suggesting the additional method that
    breaks the ambiguity.

    Args:
        *types: The types this implementation accepts, one per positional
            argument, in order.
        namespace (keyword, optional): Dict mapping function names to
            ``Dispatcher`` objects. Defaults to ``global_namespace``. It is not
            consulted for methods, which are collected in the enclosing class
            body instead. Other keyword arguments are ignored.

    Returns:
        A decorator. Applied to ``func``, it registers ``func`` under ``types``
        and returns the ``Dispatcher`` (``MethodDispatcher`` for methods) that
        holds every implementation sharing ``func.__name__``.

    Example:
        >>> # xdoctest: +SKIP
        >>> @dispatch(int)
        ... def f(x):
        ...     return x + 1
        >>> @dispatch(float)
        ... def f(x):
        ...     return x - 1
        >>> f(3)
        4
        >>> f(3.0)
        2.0
        >>> # Specify an isolated namespace with the namespace keyword argument
        >>> my_namespace = {}
        >>> @dispatch(int, namespace=my_namespace)
        ... def foo(x):
        ...     return x + 1
        >>> # Dispatch on instance methods within classes
        >>> class MyClass(object):
        ...     @dispatch(list)
        ...     def __init__(self, data):
        ...         self.data = data
        ...     @dispatch(int)
        ...     def __init__(self, datum):
        ...         self.data = [datum]
        >>> MyClass([1, 2, 3]).data
        [1, 2, 3]
        >>> MyClass(3).data
        [3]
    """
    namespace = kwargs.get('namespace', global_namespace)

    def _df(func):
        name = func.__name__

        if ismethod(func):
            # When used as a decorator inside a class body, the caller's frame
            # is the class body that is still executing, so its locals hold any
            # dispatcher already created for an earlier overload of this name.
            # The frame object is not stored in a local, which avoids a
            # frame <-> local reference cycle. NOTE: inspect.currentframe() may
            # return None on interpreters without stack-frame support.
            class_locals = inspect.currentframe().f_back.f_locals  # type: ignore[union-attr]
            if name in class_locals:
                dispatcher = class_locals[name]
            else:
                # Constructed only when needed, not on every decoration.
                dispatcher = MethodDispatcher(name)
        else:
            if name not in namespace:
                namespace[name] = Dispatcher(name)
            dispatcher = namespace[name]

        dispatcher.add(types, func)
        return dispatcher

    return _df
69
70
def ismethod(func):
    """Is func a method?

    Note that this has to work as the method is defined but before the class is
    defined.  At this stage methods look like functions, so the check is purely
    by parameter name: ``func`` counts as a method when its signature contains
    a parameter named ``self``. Already-bound methods report ``False``,
    because ``inspect.signature`` omits the bound ``self``.

    Args:
        func: The callable to inspect.

    Returns:
        ``True`` if ``func``'s signature has a ``self`` parameter, else ``False``.

    Raises:
        TypeError, ValueError: Propagated from ``inspect.signature`` when
            ``func`` is unsupported or has no retrievable signature.
    """
    # inspect.signature exists on every Python 3 release, so the former
    # getargspec/getfullargspec fallback was unreachable (getargspec itself
    # was removed in Python 3.11).
    return 'self' in inspect.signature(func).parameters
85