# mypy: ignore-errors

from __future__ import annotations

import builtins
import math
import operator
from typing import Sequence

import torch

from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
    ArrayLike,
    normalize_array_like,
    normalizer,
    NotImplementedType,
)


newaxis = None

FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    def __init__(self, flag_to_value: dict):
        assert all(k in FLAGS for k in flag_to_value.keys())  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        if attr.islower() and attr.upper() in FLAGS:
            return self[attr.upper()]
        else:
            raise AttributeError(f"No flag attribute '{attr}'")

    def __getitem__(self, key):
        if key in SHORTHAND_TO_FLAGS.keys():
            key = SHORTHAND_TO_FLAGS[key]
        if key in FLAGS:
            try:
                return self._flag_to_value[key]
            except KeyError as e:
                raise NotImplementedError(f"{key=}") from e
        else:
            raise KeyError(f"No flag key '{key}'")

    def __setattr__(self, attr, value):
        if attr.islower() and attr.upper() in FLAGS:
            self[attr.upper()] = value
        else:
            super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
            raise NotImplementedError("Modifying flags is not implemented")
        else:
            raise KeyError(f"No flag key '{key}'")

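# A rough usage sketch of Flags (illustrative only; an ndarray's .flags property
# further down returns one of these wrappers):
#
#   >>> a = asarray([[1, 2], [3, 4]])
#   >>> a.flags["C"]        # the shorthand "C" resolves to "C_CONTIGUOUS"
#   True
#   >>> a.flags.owndata     # lowercase attribute access maps to the "OWNDATA" key
#   True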

def create_method(fn, name=None):
    name = name or fn.__name__

    def f(*args, **kwargs):
        return fn(*args, **kwargs)

    f.__name__ = name
    f.__qualname__ = f"ndarray.{name}"
    return f


# Map ndarray.name_method -> np.name_func
# If name_func == None, it means that name_method == name_func
methods = {
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}

dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# dunder methods with reflected (__r*__) and in-place (__i*__) variants
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}

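# A sketch of what the tables above produce once the registration loops in the
# ndarray class body below have run (illustrative; the wrapping itself is done
# by create_method):
#
#   >>> a = asarray([3.0, 4.0])
#   >>> a - 1.0      # ndarray.__sub__   -> _ufuncs.subtract(a, 1.0)
#   >>> 1.0 - a      # ndarray.__rsub__  -> _ufuncs.subtract(1.0, a)
#   >>> a -= 1.0     # ndarray.__isub__  -> _ufuncs.subtract(a, 1.0, out=a)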

def _upcast_int_indices(index):
    if isinstance(index, torch.Tensor):
        if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
            return index.to(torch.int64)
    elif isinstance(index, tuple):
        return tuple(_upcast_int_indices(i) for i in index)
    return index

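# For instance (an illustrative check): small integer index tensors are promoted
# so that torch indexing accepts them,
#
#   >>> idx = torch.tensor([0, 2], dtype=torch.int32)
#   >>> _upcast_int_indices(idx).dtype
#   torch.int64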

# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
    pass


_Unspecified.unspecified = _Unspecified()

###############################################################
#                      ndarray class                          #
###############################################################


class ndarray:
    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        return self.tensor.storage().nbytes()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order}) is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok}) is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)
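
    # Rough usage (a sketch; only the default order/casting/subok/copy values are accepted):
    #
    #   >>> asarray([1, 2, 3]).astype("float32")   # new ndarray backed by a float32 tensor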

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()
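
    # For example (a sketch of the zero-fill behaviour noted above):
    #
    #   >>> a = asarray([1, 2, 3])
    #   >>> a.resize(5)      # grows in place; the two new trailing elements become 0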

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)
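
    # For example (illustrative; relies on Tensor.view(dtype) reinterpreting the buffer):
    #
    #   >>> asarray([1.0, 2.0], dtype="float64").view("int64")   # same bytes, read as int64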

    @normalizer
    def fill(self, value: ArrayLike):
        # Both PyTorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axes=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(i, s):
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (i,))

            # Account for the fact that a slice includes the start but not the end
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            start = s.stop + 1 if s.stop else None
            stop = s.start + 1 if s.start else None

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))
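
    # Negative-step slices have no direct torch equivalent, so neg_step above
    # rewrites them on a flipped tensor. Roughly (an illustrative sketch):
    #
    #   >>> a = asarray([1, 2, 3, 4])
    #   >>> a[::-1]    # torch.flip along axis 0, then an equivalent positive-step slice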

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recursively convert tensors into lists."""
    a1 = []
    for elem in obj:
        if isinstance(elem, (list, tuple)):
            elem = _tolist(elem)
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    return a1

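# For example (a sketch): mixed nesting flattens into plain Python lists,
#
#   >>> _tolist([1, [2, 3], asarray(4)])
#   [1, [2, 3], 4]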

# Ideally, this is the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.


def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)

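# Rough usage (illustrative):
#
#   >>> array([[1, 2], [3, 4]], dtype="float32")   # always builds a fresh ndarray here
#   >>> asarray(arr)   # `arr` being an existing ndarray: returned as-is, no copy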

def asarray(a, dtype=None, order="K", *, like=None):
    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def ascontiguousarray(a, dtype=None, *, like=None):
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        arr.tensor = arr.tensor.contiguous()
    return arr


def from_dlpack(x, /):
    t = torch.from_dlpack(x)
    return ndarray(t)


def _extract_dtype(entry):
    try:
        dty = _dtypes.dtype(entry)
    except Exception:
        dty = asarray(entry).dtype
    return dty


def can_cast(from_, to, casting="safe"):
    from_ = _extract_dtype(from_)
    to_ = _extract_dtype(to)

    return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
    tensors = []
    for entry in arrays_and_dtypes:
        try:
            t = asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            dty = _dtypes.dtype(entry)
            t = torch.empty(1, dtype=dty.torch_dtype)
        tensors.append(t)

    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
    return _dtypes.dtype(torch_dtype)
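

# For example (illustrative; the promotion rules live in _dtypes_impl):
#
#   >>> result_type(asarray([1, 2]), "float32")   # returns the promoted _dtypes.dtype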