xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/multipledispatch/variadic.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from .utils import typename
3
# Public names exported by ``from ... import *``.
__all__ = [
    "VariadicSignatureType",
    "isvariadic",
    "VariadicSignatureMeta",
    "Variadic",
]
5
class VariadicSignatureType(type):
    """Metaclass of the classes produced by ``Variadic[...]``.

    Every instance (i.e. every such class) carries a ``variadic_type``
    attribute holding the types that a single variadic argument may match.
    """

    def __subclasscheck__(cls, subclass):
        """Return whether ``subclass`` fits this variadic signature.

        A non-variadic ``subclass`` fits when it is a subclass of any entry in
        ``cls.variadic_type``. A variadic ``subclass`` fits when every entry of
        its own ``variadic_type`` does. ``cls`` always fits itself.
        """
        if isvariadic(subclass):
            candidates = subclass.variadic_type
        else:
            candidates = (subclass,)
        if subclass is cls:
            return True
        return all(
            issubclass(candidate, cls.variadic_type)  # type: ignore[attr-defined]
            for candidate in candidates
        )

    def __eq__(cls, other):
        """
        Compare two variadic types by the set of types they accept.

        Parameters
        ----------
        other : object (type)
            The value to compare against.

        Returns
        -------
        bool
            True when ``other`` is variadic and its ``variadic_type`` holds the
            same set of types as ours (order and duplicates ignored), else
            False.
        """
        if not isvariadic(other):
            return False
        return set(cls.variadic_type) == set(other.variadic_type)  # type: ignore[attr-defined]

    def __hash__(cls):
        # Hash the unordered member set so the result agrees with __eq__,
        # which also ignores ordering and duplicates.
        accepted = frozenset(cls.variadic_type)  # type: ignore[attr-defined]
        return hash((type(cls), accepted))
32
33
def isvariadic(obj):
    """Report whether ``obj`` is a variadic signature type.

    Types built by ``Variadic[...]`` are instances of the
    ``VariadicSignatureType`` metaclass; this predicate checks exactly that.

    Parameters
    ----------
    obj : type
        The value (normally a type) to inspect.

    Returns
    -------
    bool
        True if ``obj`` is an instance of ``VariadicSignatureType``,
        otherwise False.

    Examples
    --------
    >>> # xdoctest: +SKIP
    >>> isvariadic(int)
    False
    >>> isvariadic(Variadic[int])
    True
    """
    variadic_metaclass = VariadicSignatureType
    return isinstance(obj, variadic_metaclass)
53
54
class VariadicSignatureMeta(type):
    """A metaclass that overrides ``__getitem__`` on the class. This is used to
    generate a new type for Variadic signatures. See the Variadic class for
    examples of how this behaves.
    """

    def __getitem__(cls, variadic_type):
        """Build the variadic signature type for ``variadic_type``.

        Parameters
        ----------
        variadic_type : type or tuple of types
            The type(s) a variadic argument may be a subclass of. Note that
            ``Variadic[int, str]`` already passes the tuple ``(int, str)``.

        Returns
        -------
        VariadicSignatureType
            A new class named ``Variadic[...]`` (using ``typename`` of the
            tuple) whose ``variadic_type`` attribute is always a tuple.

        Raises
        ------
        ValueError
            If ``variadic_type`` is neither a type nor a tuple of types.
        """
        # Validate eagerly. Otherwise a bad value would only surface later,
        # when VariadicSignatureType.__subclasscheck__ hands it to issubclass().
        # Note that ``type(x)`` is always truthy, so it cannot serve as a check.
        is_single_type = isinstance(variadic_type, type)
        is_type_tuple = isinstance(variadic_type, tuple) and all(
            isinstance(member, type) for member in variadic_type
        )
        if not (is_single_type or is_type_tuple):
            raise ValueError("Variadic types must be type or tuple of types"
                             " (Variadic[int] or Variadic[(int, float)])")

        # Normalize to a tuple so equality/hashing/subclass checks can
        # always iterate ``variadic_type``.
        if not isinstance(variadic_type, tuple):
            variadic_type = (variadic_type,)
        return VariadicSignatureType(
            f'Variadic[{typename(variadic_type)}]',
            (),
            dict(variadic_type=variadic_type, __slots__=())
        )
72
73
class Variadic(metaclass=VariadicSignatureMeta):
    """A class whose getitem method can be used to generate a new type
    representing a specific variadic signature.

    ``Variadic[T]`` or ``Variadic[(T1, T2, ...)]`` returns a fresh class whose
    metaclass is ``VariadicSignatureType``. ``issubclass(X, Variadic[...])``
    is True when ``X`` is a subclass of one of the listed types. When ``X`` is
    itself variadic, it is True when every type ``X`` lists qualifies.

    Examples
    --------
    >>> # xdoctest: +SKIP
    >>> Variadic[int]  # any number of int arguments
    <class 'torch.fx.experimental.unification.multipledispatch.variadic.Variadic[int]'>
    >>> Variadic[(int, str)]  # any number of one of int or str arguments
    <class 'torch.fx.experimental.unification.multipledispatch.variadic.Variadic[(int, str)]'>
    >>> issubclass(int, Variadic[int])
    True
    >>> issubclass(int, Variadic[(int, str)])
    True
    >>> issubclass(str, Variadic[(int, str)])
    True
    >>> issubclass(float, Variadic[(int, str)])
    False
    >>> issubclass(Variadic[int], Variadic[(int, str)])
    True
    """
93