# mypy: ignore-errors

from __future__ import annotations

from typing import Optional

import torch

from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
from ._normalizations import (
    ArrayLike,
    ArrayLikeOrScalar,
    CastingModes,
    DTypeLike,
    normalizer,
    NotImplementedType,
    OutArray,
)


def _ufunc_postprocess(result, out, casting):
    if out is not None:
        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
        result = torch.broadcast_to(result, out.shape)
    return result


# ############# Binary ufuncs ######################

_binary = [
    name
    for name in dir(_binary_ufuncs_impl)
    if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"]
]


NEP50_FUNCS = (
    "add",
    "subtract",
    "multiply",
    "floor_divide",
    "true_divide",
    "divide",
    "remainder",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "hypot",
    "arctan2",
    "logaddexp",
    "logaddexp2",
    "heaviside",
    "copysign",
    "fmax",
    "minimum",
    "fmin",
    "maximum",
    "fmod",
    "gcd",
    "lcm",
    "pow",
)


def deco_binary_ufunc(torch_func):
    """Common infra for binary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:

            def cast(x, dtype):
                if isinstance(x, torch.Tensor):
                    return _util.typecast_tensor(x, dtype, casting)
                else:
                    return torch.as_tensor(x, dtype=dtype)

            x1 = cast(x1, dtype)
            x2 = cast(x2, dtype)
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            dtype = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
        else:
            x1, x2 = _dtypes_impl.nep50_to_tensors(
                x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__
            )

        result = torch_func(x1, x2)

        return _ufunc_postprocess(result, out, casting)

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped
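

# A hedged usage sketch of what the decorator above produces, kept in comments
# so that importing this module stays side-effect free. `tnp` stands for the
# public torch._numpy namespace that re-exports the wrapped ufuncs (an
# assumption about the import path; see __init__.py):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.add(tnp.asarray([1, 2]), tnp.asarray([3, 4]))  # tensor/tensor: result_type + cast
#   array([4, 6])
#   >>> tnp.add(tnp.asarray([1, 2], dtype="uint8"), 3)  # NEP 50: the python scalar is "weak",
#   array([4, 5], dtype=uint8)                          # so the array dtype wins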


# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    result = _binary_ufuncs_impl.matmul(x1, x2)

    result = _ufunc_postprocess(result, out, casting)
    return result


# ldexp casting is special: the dtype of the result == dtype of the 1st arg
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    if dtype is not None:
        if isinstance(x1, torch.Tensor):
            x1 = _util.typecast_tensor(x1, dtype, casting)
        else:
            x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        if not isinstance(x1, torch.Tensor):
            x1 = torch.as_tensor(x1)
            x1 = _util.cast_int_to_float(x1)

    x2 = torch.as_tensor(x2)
    # the second arg must be integer
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    if x1.dtype == torch.float16:
        # torch.ldexp(f16, int) -> f32, undo it
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)


# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    # make sure we either have no out arrays at all, or have either
    # out1/out2 or out=tuple, but not both
    num_outs = sum(x is not None for x in [out1, out2])
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    elif num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    quot = _ufunc_postprocess(quot, out1, casting)
    rem = _ufunc_postprocess(rem, out2, casting)
    return quot, rem


#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _binary:
    ufunc = getattr(_binary_ufuncs_impl, name)
    vars()[name] = deco_binary_ufunc(ufunc)
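

# A hedged sketch of the out-argument contract that `divmod` above implements
# (comments only; `a`, `b` are inputs and `q`, `r` are preallocated ndarrays --
# hypothetical names for illustration):
#
#   >>> tnp.divmod(a, b, out=(q, r))        # keyword tuple form: accepted
#   >>> tnp.divmod(a, b, q, r)              # positional form: accepted
#   >>> tnp.divmod(a, b, q)                 # ValueError: both out1 and out2 needed
#   >>> tnp.divmod(a, b, q, r, out=(q, r))  # TypeError: 'out' given twice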


def modf(x, /, *args, **kwds):
    quot, rem = divmod(x, 1, *args, **kwds)
    return rem, quot


_binary = _binary + ["divmod", "modf", "matmul", "ldexp"]


# ############# Unary ufuncs ######################


_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not name.startswith("_") and name != "torch"
]


# these are ufunc(int) -> float
_fp_unary = [
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "cbrt",
    "cos",
    "cosh",
    "deg2rad",
    "degrees",
    "exp",
    "exp2",
    "expm1",
    "log",
    "log10",
    "log1p",
    "log2",
    "rad2deg",
    "radians",
    "reciprocal",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]


def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where=True,
        casting: Optional[CastingModes] = "same_kind",
        order="K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature=None,
        extobj=None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped


#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _unary:
    ufunc = getattr(_unary_ufuncs_impl, name)
    vars()[name] = deco_unary_ufunc(ufunc)


__all__ = _binary + _unary  # noqa: PLE0605
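

# A minimal, hedged smoke test; run with `python -m torch._numpy._ufuncs` so
# the relative imports resolve. The reprs in the comments are assumptions
# about the ndarray repr, not guaranteed output.
if __name__ == "__main__":
    import torch._numpy as tnp

    a = tnp.asarray([1.5, 2.5])
    print(tnp.add(a, 1))  # weak python scalar: result keeps a's float dtype
    print(tnp.modf(a))  # (fractional, integral) parts, e.g. (array([0.5, 0.5]), array([1., 2.]))
    print(tnp.sqrt(tnp.asarray([4, 9])))  # int input is upcast to float first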