# mypy: ignore-errors

""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
"""
from __future__ import annotations

import functools
import inspect
import operator
import typing

import torch

from . import _dtypes, _dtypes_impl, _util


ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")

# OutArray is used to annotate the out= array argument.
#
# This one is special in several respects:
# First, it needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So we never pass the out array to implementer functions and handle it in the
# `normalizer` below.
# Second, the out= argument can be either a keyword or a positional argument, and
# as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")
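
# A minimal sketch of the `result is out` contract preserved here (a
# hypothetical call through this module's ndarray API):
#
#   out = zeros(3)
#   result = add(ones(3), ones(3), out=out)
#   assert result is out   # the very same ndarray object is returned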

try:
    # NotImplementedType lives in `types` (Python 3.10+), not `typing`; the
    # original `from typing import NotImplementedType` could never succeed
    # and always fell through to the TypeVar fallback.
    from types import NotImplementedType
except ImportError:
    NotImplementedType = typing.TypeVar("NotImplementedType")


def normalize_array_like(x, parm=None):
    from ._ndarray import asarray

    return asarray(x).tensor
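
# For illustration (a sketch; exact dtypes depend on the configured defaults):
#
#   normalize_array_like([1, 2, 3])   # -> tensor([1, 2, 3])
#   normalize_array_like(2.0)         # -> tensor(2.)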


def normalize_array_like_or_scalar(x, parm=None):
    if _dtypes_impl.is_scalar_or_symbolic(x):
        return x
    return normalize_array_like(x, parm)


def normalize_optional_array_like_or_scalar(x, parm=None):
    if x is None:
        return None
    return normalize_array_like_or_scalar(x, parm)


def normalize_optional_array_like(x, parm=None):
    # This explicit normalizer is needed because otherwise normalize_array_like
    # does not run for a parameter annotated as Optional[ArrayLike]
    return None if x is None else normalize_array_like(x, parm)


def normalize_seq_array_like(x, parm=None):
    return tuple(normalize_array_like(value) for value in x)


def normalize_dtype(dtype, parm=None):
    # cf _decorators.dtype_to_torch
    torch_dtype = None
    if dtype is not None:
        dtype = _dtypes.dtype(dtype)
        torch_dtype = dtype.torch_dtype
    return torch_dtype
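
# For illustration (a sketch; the name-to-dtype mapping lives in _dtypes):
#
#   normalize_dtype("float32")   # -> torch.float32
#   normalize_dtype(None)        # -> None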


def normalize_not_implemented(arg, parm):
    if arg != parm.default:
        raise NotImplementedError(f"'{parm.name}' parameter is not supported.")


def normalize_axis_like(arg, parm=None):
    from ._ndarray import ndarray

    if isinstance(arg, ndarray):
        arg = operator.index(arg)
    return arg


def normalize_ndarray(arg, parm=None):
    # check that the arg is an ndarray and extract its tensor attribute
    if arg is None:
        return arg

    from ._ndarray import ndarray

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")
    return arg.tensor
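
# Note: unlike ArrayLike, an NDArray annotation performs no asarray coercion:
# passing, e.g., a plain list for an NDArray parameter raises TypeError.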


def normalize_outarray(arg, parm=None):
    # like normalize_ndarray, but return the ndarray itself, not its tensor
    if arg is None:
        return arg
    from ._ndarray import ndarray

    # Dynamo can pass torch tensors as out arguments;
    # wrap them in an ndarray before processing
    if isinstance(arg, torch.Tensor):
        arg = ndarray(arg)

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")
    return arg


def normalize_casting(arg, parm=None):
    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
        raise ValueError(
            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
        )
    return arg


normalizers = {
    "ArrayLike": normalize_array_like,
    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
    "Optional[ArrayLike]": normalize_optional_array_like,
    "Sequence[ArrayLike]": normalize_seq_array_like,
    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
    "Optional[NDArray]": normalize_ndarray,
    "Optional[OutArray]": normalize_outarray,
    "NDArray": normalize_ndarray,
    "Optional[DTypeLike]": normalize_dtype,
    "AxisLike": normalize_axis_like,
    "NotImplementedType": normalize_not_implemented,
    "Optional[CastingModes]": normalize_casting,
}
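
# Note: because of `from __future__ import annotations`, parameter annotations
# are plain strings at runtime, so the lookup above is on the literal annotation
# text. A sketch with a hypothetical signature:
#
#   def mean(a: ArrayLike, axis: AxisLike = None, dtype: Optional[DTypeLike] = None): ...
#
#   inspect.signature(mean).parameters["a"].annotation == "ArrayLike"   # a string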


def maybe_normalize(arg, parm):
    """Normalize arg if a normalizer is registered."""
    normalizer = normalizers.get(parm.annotation, None)
    return normalizer(arg, parm) if normalizer else arg


# ### Return value helpers ###


def maybe_copy_to(out, result, promote_scalar_result=False):
    # NB: here out is an ndarray, a (nested) sequence of ndarrays, or None
    if out is None:
        return result
    elif isinstance(result, torch.Tensor):
        if result.shape != out.shape:
            can_fit = result.numel() == 1 and out.ndim == 0
            if promote_scalar_result and can_fit:
                result = result.squeeze()
            else:
                raise ValueError(
                    f"Bad size of the out array: out.shape = {out.shape}"
                    f" while result.shape = {result.shape}."
                )
        out.tensor.copy_(result)
        return out
    elif isinstance(result, (tuple, list)):
        return type(result)(
            maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
        )
    else:
        raise AssertionError  # We should never hit this path

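# Sketch of the promote_scalar_result escape hatch: a single-element result may
# be squeezed into a 0-dim out array; any other shape mismatch is an error.
#
#   out.shape == (),   result.shape == (1,)  -> result.squeeze() copied into out
#   out.shape == (2,), result.shape == (3,)  -> ValueError
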

def wrap_tensors(result):
    from ._ndarray import ndarray

    if isinstance(result, torch.Tensor):
        return ndarray(result)
    elif isinstance(result, (tuple, list)):
        result = type(result)(wrap_tensors(x) for x in result)
    return result


def array_or_scalar(values, py_type=float, return_scalar=False):
    if return_scalar:
        return py_type(values.item())
    else:
        from ._ndarray import ndarray

        return ndarray(values)


# ### The main decorator to normalize arguments / postprocess the output ###


def normalizer(_func=None, *, promote_scalar_result=False):
    def normalizer_inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwds):
            sig = inspect.signature(func)
            params = sig.parameters
            first_param = next(iter(params.values()))

            # NumPy's API does not have positional args before variadic positional args
            if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
                args = [maybe_normalize(arg, first_param) for arg in args]
            else:
                # NB: extra unknown arguments: pass through, will raise in func(*args) below
                args = (
                    tuple(
                        maybe_normalize(arg, parm)
                        for arg, parm in zip(args, params.values())
                    )
                    + args[len(params.values()) :]
                )

            kwds = {
                name: maybe_normalize(arg, params[name]) if name in params else arg
                for name, arg in kwds.items()
            }

            result = func(*args, **kwds)

            # keepdims
            bound_args = None
            if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
                # keepdims can be in any position so we need sig.bind
                bound_args = sig.bind(*args, **kwds).arguments
                if bound_args.get("keepdims", False):
                    # In this case the first arg is the initial tensor and
                    # the second arg is (optionally) the axis
                    tensor = args[0]
                    axis = bound_args.get("axis")
                    result = _util.apply_keepdims(result, axis, tensor.ndim)

            # out
            if "out" in params:
                # out can be in any position so we need sig.bind
                if bound_args is None:
                    bound_args = sig.bind(*args, **kwds).arguments
                out = bound_args.get("out")
                result = maybe_copy_to(out, result, promote_scalar_result)
            result = wrap_tensors(result)

            return result

        return wrapped

    if _func is None:
        return normalizer_inner
    else:
        return normalizer_inner(_func)
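
# A minimal sketch of the decorator in action (hypothetical wrapped function):
#
#   @normalizer
#   def hypot(x1: ArrayLike, x2: ArrayLike, out: Optional[OutArray] = None):
#       # x1 and x2 arrive here as torch.Tensors; out is handled by the wrapper
#       return torch.hypot(x1, x2)
#
# Calling hypot([3.0], [4.0]) converts both lists to tensors, runs the torch op,
# copies into `out` if one was given (returning that very ndarray), and otherwise
# wraps the fresh result tensor in an ndarray.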
260