# mypy: ignore-errors

from __future__ import annotations

import builtins
import math
import operator
from typing import Sequence

import torch

from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
    ArrayLike,
    normalize_array_like,
    normalizer,
    NotImplementedType,
)


newaxis = None

# Full names of the flags a ``Flags`` object may be built from.
FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

# Short aliases accepted as keys by ``Flags.__getitem__``.
SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    """Read-only, NumPy-style view of an array's flags.

    A flag can be read as an upper-case key (``flags["C_CONTIGUOUS"]``), a
    shorthand key (``flags["C"]``) or a lower-case attribute
    (``flags.c_contiguous``). Reading a known flag whose value was not
    supplied raises ``NotImplementedError``, as does any attempt to set one.
    """

    def __init__(self, flag_to_value: dict):
        # sanity check: every supplied key must be a known flag name
        assert set(flag_to_value).issubset(FLAGS)
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Only reached when normal lookup fails, e.g. ``flags.owndata``.
        flag = attr.upper()
        if attr.islower() and flag in FLAGS:
            return self[flag]
        raise AttributeError(f"No flag attribute '{attr}'")

    def __getitem__(self, key):
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # a valid flag name whose value this implementation does not know
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        # Lower-case flag names are routed to __setitem__ (which refuses);
        # everything else is an ordinary attribute.
        if not (attr.islower() and attr.upper() in FLAGS):
            super().__setattr__(attr, value)
            return
        self[attr.upper()] = value

    def __setitem__(self, key, value):
        if key not in FLAGS and key not in SHORTHAND_TO_FLAGS:
            raise KeyError(f"No flag key '{key}'")
        raise NotImplementedError("Modifying flags is not implemented")


def create_method(fn, name=None):
    """Wrap ``fn`` in a plain function suitable for use as an ndarray method.

    The wrapper forwards every argument to ``fn`` unchanged. Its ``__name__``
    is ``name`` (``fn.__name__`` when not given) and its ``__qualname__`` is
    ``"ndarray.<name>"``.
    """
    if not name:
        name = fn.__name__

    def method(*args, **kwargs):
        return fn(*args, **kwargs)

    method.__name__ = name
    method.__qualname__ = f"ndarray.{name}"
    return method
# Map ndarray.name_method -> np.name_func
# If name_func == None, it means that name_method == name_func
methods = {
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}

# Unary and comparison dunders: "<key>" names ndarray.__<key>__, the value the
# ufunc implementing it (None: the ufunc has the same name as the key).
dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# dunder methods with right-looking and in-place variants
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}


def _upcast_int_indices(index):
    """Promote narrow integer index tensors to int64, recursing into tuples.

    Tensors of dtype int8/int16/int32/uint8 are converted with ``.to(int64)``;
    tuples are rebuilt element-wise; anything else is returned as-is (the
    same object).
    """
    if isinstance(index, tuple):
        return tuple(map(_upcast_int_indices, index))
    narrow_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.uint8)
    if isinstance(index, torch.Tensor) and index.dtype in narrow_int_dtypes:
        return index.to(torch.int64)
    return index


# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
    """Sentinel type; its single instance is ``_Unspecified.unspecified``."""


_Unspecified.unspecified = _Unspecified()

###############################################################
#                      ndarray class                          #
###############################################################
class ndarray:
    """NumPy-compatible array implemented on top of a ``torch.Tensor``.

    The wrapped tensor is stored in ``self.tensor``. Many methods are generated
    from the ``methods``/``dunder``/``ri_dunder`` tables by delegating to the
    ``_funcs`` and ``_ufuncs`` implementations.
    """

    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        rvar = f"__r{method}__"
        # `fn=fn` binds the current ufunc; a bare closure would see only the last one
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        # torch strides count elements, NumPy strides count bytes
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes consumed by the array's elements (``size * itemsize``).

        The backing storage can be larger than this (e.g. for a slice of a
        bigger array), so its size is deliberately not used.
        """
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order} is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting} is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck} is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        """Length of the first axis; like NumPy, a 0-d array has no length."""
        if self.tensor.ndim == 0:
            raise TypeError("len() of unsized object")
        return self.tensor.shape[0]

    def __contains__(self, x):
        # Tensor.__contains__ accepts tensors and scalars only, so unwrap ndarrays.
        if isinstance(x, ndarray):
            x = x.tensor
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        """Return a single element of the array as a Python scalar.

        Mimics NumPy's three cases: no arguments (the array must hold exactly
        one element), one int (a flat index into the raveled array), and a
        multi-index given either as several ints or as one tuple.
        See https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        """
        if args == ():
            return self.tensor.item()
        if len(args) == 1 and not isinstance(args[0], tuple):
            # int argument: flat index
            return self.ravel()[args[0]].tensor.item()
        # multi-index: item(i, j, ...) or item((i, j, ...))
        index = args[0] if len(args) == 1 else args
        return self.__getitem__(index).tensor.item()

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(i, s):
            # torch has no negative slice steps: flip dimension `i` and rewrite
            # `s` as the equivalent positive-step slice of the flipped tensor.
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            size = tensor.shape[i]
            tensor = torch.flip(tensor, (i,))

            # Resolve None/negative bounds against the dimension size, then
            # mirror them: position k before the flip is size - 1 - k after it.
            # A resolved stop of -1 ("run past the front") becomes `size`.
            start, stop, step = s.indices(size)
            return slice(size - 1 - start, size - 1 - stop, -step)

        def consumed_dims(entry):
            # Number of dimensions of the indexed tensor one index entry uses:
            # None / Python bools add an axis, boolean masks cover their ndim.
            if entry is None or entry is Ellipsis or isinstance(entry, bool):
                return 0
            if isinstance(entry, ndarray):
                entry = entry.tensor
            if isinstance(entry, torch.Tensor) and entry.dtype == torch.bool:
                return entry.ndim
            return 1

        if isinstance(index, Sequence):
            entries = list(index)
            rewritten = []
            dim = 0  # tensor dimension that the current entry applies to
            for pos, entry in enumerate(entries):
                if entry is Ellipsis:
                    # `...` spans every dimension not claimed by later entries
                    dim = tensor.ndim - builtins.sum(
                        consumed_dims(e) for e in entries[pos + 1 :]
                    )
                else:
                    entry = neg_step(dim, entry)
                    dim += consumed_dims(entry)
                rewritten.append(entry)
            index = type(index)(rewritten)
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recursively convert nested lists/tuples into lists.

    Nested lists/tuples become lists, ndarray elements are replaced by
    ``elem.tensor.tolist()``, and any other element is kept as-is.
    """
    a1 = []
    for elem in obj:
        if isinstance(elem, (list, tuple)):
            elem = _tolist(elem)
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    return a1


# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Create an ndarray from ``obj`` (the counterpart of ``np.array``).

    An existing ndarray is returned unchanged when no copy, no dtype change and
    no extra dimensions are requested. Lists/tuples made only of tensors are
    stacked. Other lists/tuples have their ndarray elements converted to Python
    lists first. The final conversion is delegated to
    ``_util._coerce_to_tensor``.

    Raises:
        NotImplementedError: for ``subok``, ``like`` or an ``order`` other
            than ``"K"``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError(f"array(..., order={order!r}) is not supported.")

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)


def asarray(a, dtype=None, order="K", *, like=None):
    """``array(a, ...)`` with ``copy=False``: may return ``a`` itself."""
    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def ascontiguousarray(a, dtype=None, *, like=None):
    """Return a C-contiguous array with the contents of ``a``.

    If ``a`` converts to an already-contiguous array, that array is returned
    as-is. Otherwise a new ndarray wrapping a contiguous copy is returned.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        # Build a new ndarray instead of rebinding `arr.tensor`: `asarray` may
        # return the caller's own ndarray, which must not be modified.
        return ndarray(arr.tensor.contiguous())
    return arr


def from_dlpack(x, /):
    """Wrap a DLPack-capable object as an ndarray via ``torch.from_dlpack``."""
    t = torch.from_dlpack(x)
    return ndarray(t)


def _extract_dtype(entry):
    """Interpret ``entry`` as a dtype, else use the dtype of ``asarray(entry)``."""
    try:
        dty = _dtypes.dtype(entry)
    except Exception:
        dty = asarray(entry).dtype
    return dty


def can_cast(from_, to, casting="safe"):
    """Whether ``from_`` can be cast to ``to`` under the ``casting`` rule.

    Both arguments may be dtype-likes or array-likes.
    """
    from_ = _extract_dtype(from_)
    to_ = _extract_dtype(to)

    return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
    """Dtype resulting from combining the given arrays and/or dtypes.

    Entries that cannot be converted to an array are interpreted as dtypes and
    represented by a one-element tensor of that dtype.
    """
    tensors = []
    for entry in arrays_and_dtypes:
        try:
            t = asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            dty = _dtypes.dtype(entry)
            t = torch.empty(1, dtype=dty.torch_dtype)
        tensors.append(t)

    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
    return _dtypes.dtype(torch_dtype)