# mypy: ignore-errors

"""Assorted utilities, which do not need anything other than torch and stdlib.
"""

import operator

import torch

from . import _dtypes_impl


# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    """Return True if ``seq`` supports ``len()`` and is not a ``str``.

    Any object whose ``len()`` call raises is treated as a non-sequence.
    """
    if isinstance(seq, str):
        return False
    try:
        len(seq)
    except Exception:
        # Deliberately broad: anything that cannot report a length is not
        # considered a sequence (mirrors the numpy helper linked above).
        return False
    return True


class AxisError(ValueError, IndexError):
    """Out-of-bounds axis error; catchable as either ValueError or IndexError."""


class UFuncTypeError(TypeError, RuntimeError):
    """Ufunc type error; catchable as either TypeError or RuntimeError."""


def cast_if_needed(tensor, dtype):
    """Return ``tensor`` converted to ``dtype``, or unchanged if no cast is needed.

    ``dtype=None`` means "keep the tensor's current dtype".
    """
    # NB: no casting if dtype=None
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor


def cast_int_to_float(x):
    """Cast integer and bool tensors to the default float dtype; pass others through."""
    # dtype categories below 2 are the integer and bool ones (see _dtypes_impl)
    if _dtypes_impl._category(x.dtype) < 2:
        x = x.to(_dtypes_impl.default_dtypes().float_dtype)
    return x


# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    """Map a possibly-negative axis ``ax`` into the range ``[0, ndim)``.

    Parameters
    ----------
    ax : int
        The axis to normalize; valid values are ``-ndim <= ax < ndim``.
    ndim : int
        The number of dimensions to normalize against.
    argname : str, optional
        If given, prefixed to the error message as ``"{argname}: ..."``.

    Returns
    -------
    int
        The non-negative axis index.

    Raises
    ------
    AxisError
        If ``ax`` is out of bounds for ``ndim`` dimensions.
    """
    if not (-ndim <= ax < ndim):
        msg = f"axis {ax} is out of bounds for array of dimension {ndim}"
        if argname:
            # Same "prefix: message" format numpy uses for its msg_prefix.
            msg = f"{argname}: {msg}"
        raise AxisError(msg)
    if ax < 0:
        ax += ndim
    return ax


# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as performing the handling of negative indices covered by
    `normalize_axis_index`.

    By default, this forbids axes from being specified multiple times.
    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis indices, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    AxisError
        If any axis is out of bounds.
    ValueError
        If an axis is repeated and ``allow_duplicate`` is False.
    """
    # Optimization to speed-up the most common cases.
    if type(axis) not in (tuple, list):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            # Not a single integer: fall through and treat it as an iterable.
            pass
    # Going via an iterator directly is slower than via list comprehension.
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
    if not allow_duplicate and len(set(map(int, axis))) != len(axis):
        if argname:
            raise ValueError(f"repeated axis in `{argname}` argument")
        else:
            raise ValueError("repeated axis")
    return axis


def allow_only_single_axis(axis):
    """Return the single axis held in ``axis``, or None if ``axis`` is None.

    Assumes ``axis`` is a sized, indexable container (e.g. an already
    normalized tuple) when not None.

    Raises
    ------
    NotImplementedError
        If ``axis`` does not contain exactly one element.
    """
    if axis is None:
        return axis
    if len(axis) != 1:
        raise NotImplementedError("does not handle tuple axis")
    return axis[0]


def expand_shape(arr_shape, axis):
    """Return ``arr_shape`` as a list with size-1 dims inserted at ``axis``.

    ``axis`` (an int or a list/tuple of ints) refers to positions in the
    *output* shape, whose rank is ``len(arr_shape) + len(axis)``.
    """
    # taken from numpy 1.23.x, expand_dims function
    if type(axis) not in (list, tuple):
        axis = (axis,)
    out_ndim = len(axis) + len(arr_shape)
    axis = normalize_axis_tuple(axis, out_ndim)
    shape_it = iter(arr_shape)
    # Positions listed in `axis` get a 1; all others consume the input dims in order.
    shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
    return shape


def apply_keepdims(tensor, axis, ndim):
    """Reinsert reduced dimensions as size-1 dims (numpy ``keepdims=True``).

    With ``axis=None`` the tensor is expanded to shape ``(1,) * ndim`` and made
    contiguous; otherwise size-1 dims are inserted at ``axis`` via
    `expand_shape` and the tensor is reshaped accordingly.
    """
    if axis is None:
        # tensor was a scalar
        shape = (1,) * ndim
        tensor = tensor.expand(shape).contiguous()
    else:
        shape = expand_shape(tensor.shape, axis)
        tensor = tensor.reshape(shape)
    return tensor


def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns a ``(tensors, axis)`` pair: with ``axis=None`` every tensor is
    flattened and the returned axis is ``0``; otherwise both are returned
    unchanged.
    """
    if axis is None:
        tensors = tuple(ar.flatten() for ar in tensors)
        return tensors, 0
    else:
        return tensors, axis


def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The array dtype to cast all tensors to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule

    """
    can_cast = _dtypes_impl.can_cast_impl

    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)


def typecast_tensors(tensors, target_dtype, casting):
    """Apply `typecast_tensor` to each tensor and return the results as a tuple."""
    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)


def _try_convert_to_tensor(obj):
    """Convert ``obj`` with ``torch.as_tensor``.

    Any conversion failure is re-raised as NotImplementedError, chained to the
    original exception.
    """
    try:
        tensor = torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
        raise NotImplementedError(mesg) from e
    return tensor


def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype
    copy : bool
        Copy or not
    ndmin : int
        The result has at least this many dimensions

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Raises
    ------
    NotImplementedError
        If a non-tensor ``obj`` cannot be converted by ``torch.as_tensor``.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        # So the process-wide torch default dtype is temporarily swapped for the
        # one _dtypes_impl reports, and restored in `finally` even on failure.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed: prepend size-1 dims
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor


def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    A single argument is converted directly: an ndarray becomes its
    ``.tensor``, a tuple is converted element-wise (recursively), and anything
    else is returned as-is. Multiple arguments are treated as one tuple.

    Raises
    ------
    ValueError
        If called with no arguments.
    """
    from ._ndarray import ndarray

    if len(inputs) == 0:
        # Previously this *returned* a ValueError instance instead of raising it.
        raise ValueError("ndarrays_to_tensors() requires at least one input")
    elif len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.tensor
        elif isinstance(input_, tuple):
            return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
        else:
            return input_
    else:
        # `inputs` is the varargs tuple; recurse on it as a single tuple argument.
        return ndarrays_to_tensors(inputs)