# mypy: ignore-errors

import inspect
import itertools

from . import _funcs_impl, _reductions_impl
from ._normalizations import normalizer


# _funcs_impl.py and _reductions_impl.py contain functions which mimic NumPy's
# eponymous equivalents, and consume/return PyTorch tensors/dtypes.
# They are also type annotated.
# Pull these functions from both modules and decorate them with @normalizer, which
# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`.
# - Maps NumPy dtypes to PyTorch dtypes
# - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple
# - Implements the semantics for the `out=` arg
# - Wraps back the outputs into `torch._numpy.ndarrays`


def _public_functions(mod):
    """Return ``(name, function)`` pairs for the public functions of *mod*.

    A member counts as public when it is a plain Python function
    (``inspect.isfunction``) whose ``__name__`` does not start with an
    underscore. The pairs come from ``inspect.getmembers``, which sorts
    them by name.
    """

    def is_public_function(f):
        return inspect.isfunction(f) and not f.__name__.startswith("_")

    return inspect.getmembers(mod, is_public_function)


# We fill in __all__ in the loop below
__all__ = []

# Decorate implementer functions with argument normalizers and export them to
# this module's top-level namespace.
for name, func in itertools.chain(
    _public_functions(_funcs_impl), _public_functions(_reductions_impl)
):
    if name in ("percentile", "quantile", "median"):
        decorated = normalizer(func, promote_scalar_result=True)
    elif name == "einsum":
        # einsum is normalized manually, so it is exported undecorated. Note
        # that the name assignments below then apply to the implementation
        # function object itself.
        decorated = func
    else:
        decorated = normalizer(func)

    decorated.__qualname__ = name
    decorated.__name__ = name
    # Write into the module namespace explicitly. At module scope vars() is
    # the same dict, but globals() states the intent directly.
    globals()[name] = decorated
    __all__.append(name)


# Vendored objects from numpy.lib.index_tricks


class IndexExpression:
    """Build index tuples/slices using ordinary indexing syntax.

    ``obj[item]`` returns *item* unchanged. When *maketuple* is true and
    *item* is not already a tuple, it is wrapped in a 1-tuple instead. For
    example, ``index_exp[1:3] == (slice(1, 3, None),)`` while
    ``s_[1:3] == slice(1, 3, None)``.

    Originally written by Konrad Hinsen (last revision 1999-07-23);
    cosmetic changes by T. Oliphant 2001.
    """

    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        # Wrap a lone index so callers always receive a tuple when requested.
        if self.maketuple and not isinstance(item, tuple):
            return (item,)
        return item


index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)


__all__ += ["index_exp", "s_"]