xref: /aosp_15_r20/external/pytorch/torch/_numpy/_ufuncs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3from __future__ import annotations
4
5from typing import Optional
6
7import torch
8
9from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
10from ._normalizations import (
11    ArrayLike,
12    ArrayLikeOrScalar,
13    CastingModes,
14    DTypeLike,
15    normalizer,
16    NotImplementedType,
17    OutArray,
18)
19
20
21def _ufunc_postprocess(result, out, casting):
22    if out is not None:
23        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
24        result = torch.broadcast_to(result, out.shape)
25    return result
26
27
28# ############# Binary ufuncs ######################
29
# Public names of _binary_ufuncs_impl that get the generic binary wrapper.
# "torch" and the ufuncs with their own hand-written wrappers in this module
# (matmul, divmod, ldexp) are left out.
_binary = [
    name
    for name in dir(_binary_ufuncs_impl)
    if not (name.startswith("_") or name in ("torch", "matmul", "divmod", "ldexp"))
]
35
36
# Names checked against ``torch_func.__name__`` in deco_binary_ufunc: when an
# operand is not a tensor, membership here is passed as the third argument
# to _dtypes_impl.nep50_to_tensors.
NEP50_FUNCS = (
    # arithmetic
    "add", "subtract", "multiply",
    "floor_divide", "true_divide", "divide", "remainder",
    # bitwise
    "bitwise_and", "bitwise_or", "bitwise_xor",
    "bitwise_left_shift", "bitwise_right_shift",
    # floating-point helpers
    "hypot", "arctan2", "logaddexp", "logaddexp2", "heaviside", "copysign",
    # min / max
    "fmax", "minimum", "fmin", "maximum",
    # remaining
    "fmod", "gcd", "lcm", "pow",
)
65
66
def deco_binary_ufunc(torch_func):
    """Build the ufunc-style wrapper around a two-argument torch function.

    The returned callable normalizes its arguments (via ``normalizer``),
    brings both operands to a common dtype, calls ``torch_func`` on them and
    post-processes the result with respect to ``out``. The wrapper is given
    ``torch_func``'s ``__name__`` and ``__qualname__``.
    """
    func_name = torch_func.__name__

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:
            # Explicit dtype: tensors go through the casting-rule check,
            # anything else is converted directly to a tensor of that dtype.
            x1, x2 = [
                _util.typecast_tensor(operand, dtype, casting)
                if isinstance(operand, torch.Tensor)
                else torch.as_tensor(operand, dtype=dtype)
                for operand in (x1, x2)
            ]
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            # Two tensors: cast both to their common result type.
            dtype = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
        else:
            # At least one non-tensor operand: delegate to the NEP 50 helper.
            use_nep50 = func_name in NEP50_FUNCS
            x1, x2 = _dtypes_impl.nep50_to_tensors(x1, x2, use_nep50, func_name)

        return _ufunc_postprocess(torch_func(x1, x2), out, casting)

    wrapped.__name__ = func_name
    wrapped.__qualname__ = func_name

    return wrapped
115
116
117# matmul's signature is _slightly_ different from other ufuncs:
118# - no where=...
119# - additional axis=..., axes=...
120# - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    """Matrix product of ``x1`` and ``x2`` with ufunc-style keyword handling.

    Both operands are cast to ``dtype`` (or, when it is None, to their common
    result type) under the ``casting`` rule before
    ``_binary_ufuncs_impl.matmul`` is called; the result is then
    post-processed with respect to ``out``.
    """
    target = _dtypes_impl.result_type_impl(x1, x2) if dtype is None else dtype
    lhs, rhs = _util.typecast_tensors((x1, x2), target, casting)

    product = _binary_ufuncs_impl.matmul(lhs, rhs)
    return _ufunc_postprocess(product, out, casting)
145
146
147# ldexp casting is special : the dtype of the result == dtype of the 1st arg
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """ufunc wrapper for ``ldexp`` whose casting is driven by the first argument.

    Only ``x1`` is affected by ``dtype``; ``x2`` is converted to a tensor and
    must have an integer dtype.

    Raises:
        ValueError: if ``x2`` does not have an integer dtype.
    """
    if isinstance(x1, torch.Tensor):
        # Tensors are only touched when an explicit dtype was requested.
        if dtype is not None:
            x1 = _util.typecast_tensor(x1, dtype, casting)
    elif dtype is not None:
        x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        # Non-tensor input without a dtype: integers are promoted to float.
        x1 = _util.cast_int_to_float(torch.as_tensor(x1))

    x2 = torch.as_tensor(x2)
    # the second arg must be integer
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    # torch.ldexp(f16, int) -> f32; bring the result back to float16 so it
    # matches the dtype of the first argument.
    if x1.dtype == torch.float16:
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)
185
186
187# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """Two-output ufunc wrapper returning ``(quotient, remainder)``.

    Output arrays may be given positionally (``out1``, ``out2``) or as the
    ``out=(o1, o2)`` keyword, but not both ways at once.

    Raises:
        ValueError: if exactly one of ``out1`` / ``out2`` is provided.
        TypeError: if both positional outputs and a non-empty ``out`` tuple
            are provided.
    """
    positional_outs = sum(arr is not None for arr in (out1, out2))
    if positional_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")

    if positional_outs == 2:
        # Positional outputs win only if the keyword tuple is left empty.
        kw_first, kw_second = out
        if not (kw_first is None and kw_second is None):
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        # No positional outputs: take them from the keyword tuple.
        out1, out2 = out

    target = _dtypes_impl.result_type_impl(x1, x2) if dtype is None else dtype
    x1, x2 = _util.typecast_tensors((x1, x2), target, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    return (
        _ufunc_postprocess(quot, out1, casting),
        _ufunc_postprocess(rem, out2, casting),
    )
228
229
230#
231# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
232#
for name in _binary:
    # Wrap the raw implementation and publish it under the same name as a
    # module-level attribute.
    ufunc = getattr(_binary_ufuncs_impl, name)
    globals()[name] = deco_binary_ufunc(ufunc)
236
237
def modf(x, /, *args, **kwds):
    """Return ``(remainder, quotient)`` of ``divmod(x, 1, ...)``.

    Extra positional and keyword arguments are forwarded to ``divmod``
    unchanged.
    """
    # NOTE(review): this inherits divmod's rounding and ``out`` ordering;
    # confirm it matches numpy.modf for negative non-integer inputs.
    whole, frac = divmod(x, 1, *args, **kwds)
    return frac, whole
241
242
# Add the binary ufuncs that are defined by hand in this module.
_binary = [*_binary, "divmod", "modf", "matmul", "ldexp"]
244
245
246# ############# Unary ufuncs ######################
247
248
# Public names of _unary_ufuncs_impl, minus the re-exported "torch".
_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not (name.startswith("_") or name == "torch")
]
254
255
256# these are ufunc(int) -> float
257_fp_unary = [
258    "arccos",
259    "arccosh",
260    "arcsin",
261    "arcsinh",
262    "arctan",
263    "arctanh",
264    "cbrt",
265    "cos",
266    "cosh",
267    "deg2rad",
268    "degrees",
269    "exp",
270    "exp2",
271    "expm1",
272    "log",
273    "log10",
274    "log1p",
275    "log2",
276    "rad2deg",
277    "radians",
278    "reciprocal",
279    "sin",
280    "sinh",
281    "sqrt",
282    "square",
283    "tan",
284    "tanh",
285    "trunc",
286]
287
288
def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.

    The returned wrapper carries ``torch_func``'s ``__name__`` and
    ``__qualname__``.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        # ``where``, ``order``, ``signature`` and ``extobj`` are never used
        # by the body. They are annotated NotImplementedType, exactly as in
        # the binary wrappers of this module, so the normalizer treats them
        # the same way for unary and binary ufuncs.
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        # An explicit dtype is applied to the input before the computation.
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        # Functions listed in _fp_unary get their input passed through
        # cast_int_to_float first.
        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        return _ufunc_postprocess(result, out, casting)

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped
324
325
326#
327# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
328#
for name in _unary:
    # Wrap the raw implementation and publish it under the same name as a
    # module-level attribute.
    ufunc = getattr(_unary_ufuncs_impl, name)
    globals()[name] = deco_unary_ufunc(ufunc)
332
333
# Everything attached above is exported: binary names first, then unary.
__all__ = [*_binary, *_unary]  # noqa: PLE0605
335