xref: /aosp_15_r20/external/pytorch/torch/_numpy/_funcs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import inspect
4import itertools
5
6from . import _funcs_impl, _reductions_impl
7from ._normalizations import normalizer
8
9
# _funcs_impl.py contains functions that mimic their NumPy namesakes but
# consume and return PyTorch tensors/dtypes. They are also type annotated.
# The loop below pulls the public functions from _funcs_impl (and from
# _reductions_impl) and decorates them with @normalizer, which
# - converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists,
#   Python scalar, etc. into a `torch.Tensor`;
# - maps NumPy dtypes to PyTorch dtypes;
# - maps an ndarray passed as the `axis` kwarg to a tuple;
# - implements the semantics of the `out=` argument;
# - wraps the outputs back into `torch._numpy.ndarray`s.
19
20
def _public_functions(mod):
    """Return ``(attribute_name, function)`` pairs for the public functions of *mod*.

    A member is kept when it is a plain Python function (``inspect.isfunction``)
    whose own ``__name__`` does not start with an underscore. The check uses the
    function's ``__name__``, not the attribute name it is stored under. Pairs
    come back sorted by attribute name, as produced by ``inspect.getmembers``.
    """
    every_member = inspect.getmembers(mod)
    return [
        (attr_name, obj)
        for attr_name, obj in every_member
        if inspect.isfunction(obj) and not obj.__name__.startswith("_")
    ]
26
27
# Names exported by this module; the loop below appends one entry per
# decorated implementer function.
__all__ = []

# Decorate every public implementer function from _funcs_impl and
# _reductions_impl with the argument normalizer, then export it from this
# module's namespace under its own name.
for name, func in itertools.chain(
    _public_functions(_funcs_impl), _public_functions(_reductions_impl)
):
    if name in ("percentile", "quantile", "median"):
        decorated = normalizer(func, promote_scalar_result=True)
    elif name == "einsum":
        # einsum is normalized manually, so it is exported undecorated.
        decorated = func
    else:
        decorated = normalizer(func)

    decorated.__qualname__ = name
    decorated.__name__ = name
    # Write into the module namespace explicitly. ``vars()`` with no argument
    # is ``locals()``, whose dict is documented as not meant to be modified;
    # ``globals()`` is the module namespace itself.
    globals()[name] = decorated
    __all__.append(name)
47
48
"""
Objects vendored from ``numpy.lib.index_tricks``.
"""
52
53
class IndexExpression:
    """
    Turn subscript syntax into the index object it denotes.

    ``obj[item]`` returns ``item`` itself. If the instance was created with a
    truthy ``maketuple`` and ``item`` is not already a tuple, it returns
    ``(item,)`` instead.

    Written by Konrad Hinsen (last revision: 1999-7-23).
    Cosmetic changes by T. Oliphant 2001.
    """

    def __init__(self, maketuple):
        # Truthy: always hand back a tuple from __getitem__.
        self.maketuple = maketuple

    def __getitem__(self, item):
        wrap_in_tuple = self.maketuple and not isinstance(item, tuple)
        return (item,) if wrap_in_tuple else item
70
71
# ``index_exp[...]`` always yields a tuple; ``s_[...]`` yields the index
# object exactly as written (slice, tuple, scalar, ...).
index_exp = IndexExpression(True)
s_ = IndexExpression(False)

__all__.extend(("index_exp", "s_"))
77