# mypy: ignore-errors

"""Normalize arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.

Implementer functions annotate their parameters with the placeholder types
defined below (``ArrayLike``, ``DTypeLike``, ``OutArray``, ...). The
``normalizer`` decorator looks up each parameter's annotation in the
``normalizers`` table and converts the incoming argument accordingly before
calling the implementer.
"""
from __future__ import annotations

import functools
import inspect
import operator
import typing

import torch

from . import _dtypes, _dtypes_impl, _util


# Placeholder types used purely as annotations. Because of
# ``from __future__ import annotations`` the annotations are stored as strings,
# and the ``normalizer`` decorator dispatches on those strings.
ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")

# OutArray is to annotate the out= array argument.
#
# This one is special in several respects:
# First, it needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So the out array is not converted to a Tensor like other array arguments;
# instead, the `normalizer` below copies the implementer's result into it.
# Second, the out= argument can be either a keyword or a positional argument,
# and as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")

try:
    from typing import NotImplementedType
except ImportError:
    NotImplementedType = typing.TypeVar("NotImplementedType")


# ### Argument normalizers ###
#
# Every normalizer takes ``(arg, parm)``: the incoming argument and the
# ``inspect.Parameter`` it binds to (``parm`` may be None when a normalizer is
# called directly). The return value is what the implementer function receives.


def _parm_name(parm):
    """Return ``parm.name`` for error messages, tolerating ``parm=None``."""
    return "argument" if parm is None else parm.name


def normalize_array_like(x, parm=None):
    """Convert an array_like ``x`` to a ``torch.Tensor`` (via ``asarray``)."""
    # Imported lazily, presumably to avoid an import cycle with ``_ndarray``.
    from ._ndarray import asarray

    return asarray(x).tensor


def normalize_array_like_or_scalar(x, parm=None):
    """Like ``normalize_array_like``, but scalars are passed through unchanged.

    What counts as a scalar is decided by ``_dtypes_impl.is_scalar_or_symbolic``.
    """
    if _dtypes_impl.is_scalar_or_symbolic(x):
        return x
    return normalize_array_like(x, parm)


def normalize_optional_array_like_or_scalar(x, parm=None):
    """``normalize_array_like_or_scalar`` that lets ``None`` through."""
    if x is None:
        return None
    return normalize_array_like_or_scalar(x, parm)


def normalize_optional_array_like(x, parm=None):
    """``normalize_array_like`` that lets ``None`` through."""
    # This explicit normalizer is needed because otherwise normalize_array_like
    # does not run for a parameter annotated as Optional[ArrayLike]
    return None if x is None else normalize_array_like(x, parm)


def normalize_seq_array_like(x, parm=None):
    """Convert every element of the sequence ``x`` to a tensor; returns a tuple."""
    return tuple(normalize_array_like(value, parm) for value in x)


def normalize_dtype(dtype, parm=None):
    """Convert a dtype-like to its torch dtype; ``None`` stays ``None``."""
    # cf _decorators.dtype_to_torch
    if dtype is None:
        return None
    return _dtypes.dtype(dtype).torch_dtype


def normalize_not_implemented(arg, parm):
    """Reject any value other than the parameter's default.

    Returns None, so the implementer receives ``None`` for this parameter.

    Raises:
        NotImplementedError: if ``arg != parm.default``.
    """
    if arg != parm.default:
        raise NotImplementedError(f"'{parm.name}' parameter is not supported.")


def normalize_axis_like(arg, parm=None):
    """Convert an ndarray axis with ``operator.index``; pass anything else through."""
    from ._ndarray import ndarray

    if isinstance(arg, ndarray):
        arg = operator.index(arg)
    return arg


def normalize_ndarray(arg, parm=None):
    """Check that ``arg`` is an ndarray and return its ``.tensor``.

    ``None`` passes through unchanged.

    Raises:
        TypeError: if ``arg`` is neither None nor an ndarray.
    """
    if arg is None:
        return arg

    from ._ndarray import ndarray

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{_parm_name(parm)}' must be an array")
    return arg.tensor


def normalize_outarray(arg, parm=None):
    """Validate an ``out=`` argument; return the ndarray itself, not its tensor.

    Returning the ndarray (rather than ``.tensor`` as ``normalize_ndarray``
    does) lets ``maybe_copy_to`` write into it and return that same object.
    ``None`` passes through unchanged.

    Raises:
        TypeError: if ``arg`` is neither None, a ``torch.Tensor`` nor an ndarray.
    """
    if arg is None:
        return arg
    from ._ndarray import ndarray

    # Dynamo can pass torch tensors as out arguments,
    # wrap it in an ndarray before processing
    if isinstance(arg, torch.Tensor):
        arg = ndarray(arg)

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{_parm_name(parm)}' must be an array")
    return arg


def normalize_casting(arg, parm=None):
    """Validate a ``casting=`` mode string and return it unchanged.

    Raises:
        ValueError: if ``arg`` is not one of the five supported modes.
    """
    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
        raise ValueError(
            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
        )
    return arg


# Annotation -> normalizer. The keys are annotation *strings*: with
# ``from __future__ import annotations`` in effect for this module,
# ``inspect.Parameter.annotation`` holds the annotation's source text.
normalizers = {
    "ArrayLike": normalize_array_like,
    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
    "Optional[ArrayLike]": normalize_optional_array_like,
    "Sequence[ArrayLike]": normalize_seq_array_like,
    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
    "Optional[NDArray]": normalize_ndarray,
    "Optional[OutArray]": normalize_outarray,
    "NDArray": normalize_ndarray,
    "Optional[DTypeLike]": normalize_dtype,
    "AxisLike": normalize_axis_like,
    "NotImplementedType": normalize_not_implemented,
    "Optional[CastingModes]": normalize_casting,
}


def maybe_normalize(arg, parm):
    """Normalize arg if a normalizer is registered.

    Parameters whose annotation has no entry in ``normalizers`` (including
    unannotated ones) get ``arg`` back untouched.
    """
    normalizer = normalizers.get(parm.annotation, None)
    return normalizer(arg, parm) if normalizer else arg


# ### Return value helpers ###


def maybe_copy_to(out, result, promote_scalar_result=False):
    """Copy ``result`` into the ``out=`` array(s), if any.

    Args:
        out: ``None`` or what ``normalize_outarray`` produced. When ``result``
            is a tuple/list, ``out`` is iterated in lockstep with it (``zip``:
            surplus entries on either side are ignored).
        result: a ``torch.Tensor`` or a tuple/list of results.
        promote_scalar_result: allow a one-element tensor result to be squeezed
            into a 0-d ``out`` whose shape would otherwise not match.

    Returns:
        ``result`` unchanged when ``out`` is None; otherwise ``out`` itself
        (after copying), or a container of the same type as ``result`` holding
        the per-element return values.

    Raises:
        ValueError: if a tensor result's shape does not match ``out.shape``.
        AssertionError: if ``result`` is neither a tensor nor a tuple/list.
    """
    # NB: here out is either an ndarray or None
    if out is None:
        return result

    if isinstance(result, torch.Tensor):
        if result.shape != out.shape:
            can_fit = result.numel() == 1 and out.ndim == 0
            if not (promote_scalar_result and can_fit):
                raise ValueError(
                    f"Bad size of the out array: out.shape = {out.shape}"
                    f" while result.shape = {result.shape}."
                )
            result = result.squeeze()
        out.tensor.copy_(result)
        return out

    if isinstance(result, (tuple, list)):
        return type(result)(
            maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
        )

    # We should never hit this path
    raise AssertionError(f"unexpected result type: {type(result).__name__}")


def wrap_tensors(result):
    """Wrap a tensor (or, recursively, tensors in a tuple/list) into ndarrays.

    Anything else is returned unchanged.
    """
    from ._ndarray import ndarray

    if isinstance(result, torch.Tensor):
        return ndarray(result)
    elif isinstance(result, (tuple, list)):
        result = type(result)(wrap_tensors(x) for x in result)
    return result


def array_or_scalar(values, py_type=float, return_scalar=False):
    """Return ``values`` either as a Python scalar or wrapped in an ndarray.

    With ``return_scalar`` true the result is ``py_type(values.item())``;
    otherwise it is ``ndarray(values)``.
    """
    if return_scalar:
        return py_type(values.item())

    from ._ndarray import ndarray

    return ndarray(values)


# ### The main decorator to normalize arguments / postprocess the output ###


def normalizer(_func=None, *, promote_scalar_result=False):
    """Decorator that normalizes arguments and post-processes the result.

    Usable bare (``@normalizer``) or with options
    (``@normalizer(promote_scalar_result=True)``).

    Each argument whose parameter annotation is registered in ``normalizers``
    is converted by the matching normalizer before the wrapped function runs.
    Afterwards:

    * if the function has a ``keepdims`` parameter annotated ``KeepDims`` and
      the caller passed a truthy ``keepdims``, ``_util.apply_keepdims`` is
      applied to the result (presumably re-inserting the reduced axes);
    * if the function has an ``out`` parameter, the result is copied into it
      via ``maybe_copy_to``;
    * finally, tensors in the result are wrapped into ndarrays.
    """

    def normalizer_inner(func):
        # The signature depends only on ``func``: compute it once here rather
        # than on every call (``inspect.signature`` is comparatively costly).
        sig = inspect.signature(func)
        params = sig.parameters
        # None for a zero-parameter function (previously: StopIteration per call).
        first_param = next(iter(params.values()), None)

        @functools.wraps(func)
        def wrapped(*args, **kwds):
            # NumPy's API does not have positional args before variadic positional args
            if (
                first_param is not None
                and first_param.kind == inspect.Parameter.VAR_POSITIONAL
            ):
                args = [maybe_normalize(arg, first_param) for arg in args]
            else:
                # NB: extra unknown arguments: pass through, will raise in func(*args) below
                args = (
                    tuple(
                        maybe_normalize(arg, parm)
                        for arg, parm in zip(args, params.values())
                    )
                    + args[len(params) :]
                )

            kwds = {
                name: maybe_normalize(arg, params[name]) if name in params else arg
                for name, arg in kwds.items()
            }

            result = func(*args, **kwds)

            # keepdims
            bound_args = None
            if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
                # keepdims can be in any position so we need sig.bind
                bound_args = sig.bind(*args, **kwds).arguments
                if bound_args.get("keepdims", False):
                    # The first parameter is the initial tensor and ``axis`` is
                    # (optionally) the reduction axis. The tensor may have been
                    # passed by keyword (e.g. ``sum(a=x, keepdims=True)``), in
                    # which case ``args`` is empty and we read it from the binding.
                    tensor = args[0] if args else bound_args[first_param.name]
                    axis = bound_args.get("axis")
                    result = _util.apply_keepdims(result, axis, tensor.ndim)

            # out
            if "out" in params:
                # out can be in any position so we need sig.bind
                if bound_args is None:
                    bound_args = sig.bind(*args, **kwds).arguments
                out = bound_args.get("out")
                result = maybe_copy_to(out, result, promote_scalar_result)
            result = wrap_tensors(result)

            return result

        return wrapped

    if _func is None:
        return normalizer_inner
    return normalizer_inner(_func)