xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/core/test_multiarray.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import builtins
4import collections.abc
5import ctypes
6import functools
7import io
8import itertools
9import mmap
10import operator
11import os
12import sys
13import tempfile
14import warnings
15import weakref
16from contextlib import contextmanager
17from decimal import Decimal
18from pathlib import Path
19from tempfile import mkstemp
20from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
21
22import numpy
23import pytest
24from pytest import raises as assert_raises
25
26from torch.testing._internal.common_utils import (
27    instantiate_parametrized_tests,
28    parametrize,
29    run_tests,
30    slowTest as slow,
31    subtest,
32    TEST_WITH_TORCHDYNAMO,
33    TestCase,
34    xfailIfTorchDynamo,
35    xpassIfTorchDynamo,
36)
37
38
39# If we are going to trace through these, we should use NumPy
40# If testing on eager mode, we use torch._numpy
41if TEST_WITH_TORCHDYNAMO:
42    import numpy as np
43    from numpy.testing import (
44        assert_,
45        assert_allclose,
46        assert_almost_equal,
47        assert_array_almost_equal,
48        assert_array_equal,
49        assert_array_less,
50        assert_equal,
51        assert_raises_regex,
52        assert_warns,
53        suppress_warnings,
54    )
55
56else:
57    import torch._numpy as np
58    from torch._numpy.testing import (
59        assert_,
60        assert_allclose,
61        assert_almost_equal,
62        assert_array_almost_equal,
63        assert_array_equal,
64        assert_array_less,
65        assert_equal,
66        assert_raises_regex,
67        assert_warns,
68        suppress_warnings,
69    )
70
71
# Platform switches inherited from NumPy's own test-suite. They are hard-coded
# here rather than detected at runtime.
IS_PYPY = False
IS_PYSTON = False
HAS_REFCOUNT = True

# Unconditional skip decorator: ``skip(reason=...)`` is ``skipIf(True, reason=...)``.
# It must be *called* with a reason; a bare ``@skip`` passes the decorated
# object as the reason and returns a decorator instead of skipping.
skip = functools.partial(skipif, True)
77
78from numpy.core.tests._locales import CommaDecimalPointLocale
79from numpy.testing._private.utils import _no_tracing, requires_memory
80
81
82# #### stubs to make pytest pass the collections stage ####
83
84
85# defined in numpy/testing/_utils.py
def runstring(astr, dict):
    """Execute the Python source string *astr* with *dict* as its globals.

    Nothing is returned; any exception raised by the executed code
    propagates to the caller. (The parameter name ``dict`` shadows the
    builtin but is kept so keyword callers keep working.)
    """
    namespace = dict
    exec(astr, namespace)
88
89
@contextmanager
def temppath(*args, **kwargs):
    """Yield the path of a freshly created, already-closed temporary file.

    All arguments are forwarded unchanged to ``tempfile.mkstemp``. The OS
    descriptor returned by ``mkstemp`` is closed before the path is yielded,
    because Windows refuses to reopen a file that is still open. The file is
    deleted when the ``with`` block exits, so any handle the caller opened on
    it must be closed by then.
    """
    handle, filename = mkstemp(*args, **kwargs)
    os.close(handle)
    try:
        yield filename
    finally:
        os.remove(filename)
110
111
# FIXME: crude aliases so the tests below can call these names; the distinct
# semantics of asanyarray (subclass pass-through) and asfortranarray
# (Fortran ordering) are NOT reproduced.
# NOTE(review): when TEST_WITH_TORCHDYNAMO is set, ``np`` is the real NumPy
# module, so this rebinding patches NumPy itself for the whole process.
np.asanyarray = np.asfortranarray = np.asarray
115
116# #### end stubs
117
118
119def _aligned_zeros(shape, dtype=float, order="C", align=None):
120    """
121    Allocate a new ndarray with aligned memory.
122
123    The ndarray is guaranteed *not* aligned to twice the requested alignment.
124    Eg, if align=4, guarantees it is not aligned to 8. If align=None uses
125    dtype.alignment."""
126    dtype = np.dtype(dtype)
127    if dtype == np.dtype(object):
128        # Can't do this, fall back to standard allocation (which
129        # should always be sufficiently aligned)
130        if align is not None:
131            raise ValueError("object array alignment not supported")
132        return np.zeros(shape, dtype=dtype, order=order)
133    if align is None:
134        align = dtype.alignment
135    if not hasattr(shape, "__len__"):
136        shape = (shape,)
137    size = functools.reduce(operator.mul, shape) * dtype.itemsize
138    buf = np.empty(size + 2 * align + 1, np.uint8)
139
140    ptr = buf.__array_interface__["data"][0]
141    offset = ptr % align
142    if offset != 0:
143        offset = align - offset
144    if (ptr % (2 * align)) == 0:
145        offset += align
146
147    # Note: slices producing 0-size arrays do not necessarily change
148    # data pointer --- so we use and allocate size+1
149    buf = buf[offset : offset + size + 1][:-1]
150    buf.fill(0)
151    data = np.ndarray(shape, dtype, buf, order=order)
152    return data
153
154
@xpassIfTorchDynamo  # (reason="TODO: flags")
@instantiate_parametrized_tests
class TestFlag(TestCase):
    """Tests for ``ndarray.flags``: writeability, warn-on-write, layout, alignment."""

    def setUp(self):
        self.a = np.arange(10)

    @xfail
    def test_writeable(self):
        # Namespace (containing ``self``) used as globals by ``runstring``.
        namespace = locals()
        self.a.flags.writeable = False
        for stmt in ("self.a[0] = 3", "self.a[0:1].itemset(3)"):
            assert_raises(ValueError, runstring, stmt, namespace)
        self.a.flags.writeable = True
        self.a[0] = 5
        self.a[0] = 0

    def test_writeable_any_base(self):
        # Having *any* writeable base in the chain is enough to switch the
        # flag back on; arrays built through the array interface are the
        # interesting case here.
        owner = np.arange(10)

        class subclass(np.ndarray):
            pass

        # Viewing through a subclass keeps the base chain from being collapsed.
        sub_view = owner.view(subclass)
        nested = sub_view[...]
        owner.flags.writeable = False
        nested.flags.writeable = False
        nested.flags.writeable = True  # Can be set to True again.

        owner = np.arange(10)

        class frominterface:
            def __init__(self, arr):
                self.arr = arr
                self.__array_interface__ = arr.__array_interface__

        # NOTE(review): the class itself, not ``frominterface(owner)``, is
        # passed here, so no array interface is actually consumed -- confirm
        # whether an instance was intended.
        iface_view = np.asarray(frominterface)
        nested = iface_view[...]
        nested.flags.writeable = False
        nested.flags.writeable = True

        iface_view.flags.writeable = False
        nested.flags.writeable = False
        with assert_raises(ValueError):
            # Must assume not writeable, since only base is not:
            nested.flags.writeable = True

    def test_writeable_from_readonly(self):
        # gh-9440: arrays built by frombuffer / records.fromstring on a
        # read-only buffer must refuse to become writeable.
        raw = b"\x00" * 100
        as_bytes = np.frombuffer(raw, "B")
        assert_raises(ValueError, as_bytes.setflags, write=True)
        record_dtype = np.dtype([("vals", "u1"), ("res3", "S4")])
        field = np.core.records.fromstring(raw, record_dtype)["vals"]
        assert_raises(ValueError, field.setflags, write=True)

    def test_writeable_from_buffer(self):
        raw = bytearray(b"\x00" * 100)

        def toggle(arr):
            # A view of a mutable buffer starts writeable and can be switched
            # off and back on.
            assert_(arr.flags.writeable)
            arr.setflags(write=False)
            assert_(arr.flags.writeable is False)
            arr.setflags(write=True)
            assert_(arr.flags.writeable)

        toggle(np.frombuffer(raw, "B"))
        record_dtype = np.dtype([("vals", "u1"), ("res3", "S4")])
        toggle(np.core.records.fromstring(raw, record_dtype)["vals"])

    @skipif(IS_PYPY, reason="PyPy always copies")
    def test_writeable_pickle(self):
        import pickle

        # Small arrays are copied without setting a base (see the condition
        # for PyArray_SetBaseObject in array_setstate), hence 1000 elements.
        source = np.arange(1000)
        # NOTE(review): range() stops before pickle.HIGHEST_PROTOCOL, so the
        # highest protocol is never exercised -- confirm this is intentional.
        for protocol in range(pickle.HIGHEST_PROTOCOL):
            restored = pickle.loads(pickle.dumps(source, protocol))
            assert_(restored.flags.writeable)
            assert_(isinstance(restored.base, bytes))

    def test_warnonwrite(self):
        arr = np.arange(10)
        arr.flags._warn_on_write = True
        with warnings.catch_warnings(record=True) as caught:
            warnings.filterwarnings("always")
            arr[1] = 10
            arr[2] = 10
            # Only the first write warns.
            assert_(len(caught) == 1)

    @parametrize(
        "flag, flag_value, writeable",
        [
            ("writeable", True, True),
            # Delete _warn_on_write after deprecation and simplify
            # the parameterization:
            ("_warn_on_write", True, False),
            ("writeable", False, False),
        ],
    )
    def test_readonly_flag_protocols(self, flag, flag_value, writeable):
        arr = np.arange(10)
        setattr(arr.flags, flag, flag_value)

        class MyArr:
            __array_struct__ = arr.__array_struct__

        # The buffer protocol, __array_interface__ and __array_struct__ must
        # all report the same writeability.
        assert memoryview(arr).readonly is not writeable
        assert arr.__array_interface__["data"][1] is not writeable
        assert np.asarray(MyArr()).flags.writeable is writeable

    @xfail
    def test_otherflags(self):
        flags = self.a.flags
        assert_equal(flags.carray, True)
        assert_equal(flags["C"], True)
        assert_equal(flags.farray, False)
        assert_equal(flags.behaved, True)
        assert_equal(flags.fnc, False)
        assert_equal(flags.forc, True)
        assert_equal(flags.owndata, True)
        assert_equal(flags.writeable, True)
        assert_equal(flags.aligned, True)
        assert_equal(flags.writebackifcopy, False)
        assert_equal(flags["X"], False)
        assert_equal(flags["WRITEBACKIFCOPY"], False)

    @xfail  # invalid dtype
    def test_string_align(self):
        # not power of two are accessed byte-wise and thus considered aligned
        for count in (4, 5):
            arr = np.zeros(count, dtype=np.dtype("|S4"))
            assert_(arr.flags.aligned)

    @xfail  # structured dtypes
    def test_void_align(self):
        pair = np.zeros(4, dtype=np.dtype([("a", "i4"), ("b", "i4")]))
        assert_(pair.flags.aligned)
303
304
@xpassIfTorchDynamo  # (reason="TODO: hash")
class TestHash(TestCase):
    # see #3793
    def test_int(self):
        """Integer scalars must hash exactly like the equal Python int."""
        widths = [
            (np.int8, np.uint8, 8),
            (np.int16, np.uint16, 16),
            (np.int32, np.uint32, 32),
            (np.int64, np.uint64, 64),
        ]
        for signed, unsigned, bits in widths:
            for i in range(1, bits):
                signed_cases = (
                    (-(2**i), f"-2**{i}"),
                    (2 ** (i - 1), f"2**{i - 1}"),
                    (2**i - 1, f"2**{i} - 1"),
                )
                for value, label in signed_cases:
                    assert_equal(
                        hash(signed(value)),
                        hash(value),
                        err_msg=f"{signed!r}: {label}",
                    )

                # Unsigned types are probed one power lower (but never below 1).
                j = max(i - 1, 1)
                unsigned_cases = (
                    (2 ** (j - 1), f"2**{j - 1}"),
                    (2**j - 1, f"2**{j} - 1"),
                )
                for value, label in unsigned_cases:
                    assert_equal(
                        hash(unsigned(value)),
                        hash(value),
                        err_msg=f"{unsigned!r}: {label}",
                    )
341
342
@xpassIfTorchDynamo  # (reason="TODO: hash")
class TestAttributes(TestCase):
    """Shape, strides, dtype, size and fill() behaviour of basic arrays."""

    def setUp(self):
        self.one = np.arange(10)
        self.two = np.arange(20).reshape(4, 5)
        self.three = np.arange(60, dtype=np.float64).reshape(2, 5, 6)

    def test_attributes(self):
        one, two, three = self.one, self.two, self.three
        assert_equal(one.shape, (10,))
        assert_equal(two.shape, (4, 5))
        assert_equal(three.shape, (2, 5, 6))
        # ``shape`` is assignable: go to another shape and back again.
        three.shape = (10, 3, 2)
        assert_equal(three.shape, (10, 3, 2))
        three.shape = (2, 5, 6)
        assert_equal(one.strides, (one.itemsize,))
        two_item = two.itemsize
        assert_equal(two.strides, (5 * two_item, two_item))
        three_item = three.itemsize
        assert_equal(three.strides, (30 * three_item, 6 * three_item, three_item))
        for arr, ndim in ((one, 1), (two, 2), (three, 3)):
            assert_equal(arr.ndim, ndim)
        assert_equal(two.size, 20)
        assert_equal(two.nbytes, 20 * two_item)
        assert_equal(two.itemsize, two.dtype.itemsize)

    @xfailIfTorchDynamo  # use ndarray.tensor._base to track the base tensor
    def test_attributes_2(self):
        assert_equal(self.two.base, np.arange(20))

    def test_dtypeattr(self):
        int_dt = self.one.dtype
        float_dt = self.three.dtype
        assert_equal(int_dt, np.dtype(np.int_))
        assert_equal(float_dt, np.dtype(np.float64))
        assert_equal(int_dt.char, "l")
        assert_equal(float_dt.char, "d")
        assert_(float_dt.str[0] in "<>")
        assert_equal(int_dt.str[1], "i")
        assert_equal(float_dt.str[1], "f")

    def test_stridesattr(self):
        src = self.one

        def make_array(size, offset, strides):
            # offset/strides are given in elements and converted to bytes.
            return np.ndarray(
                size,
                buffer=src,
                dtype=int,
                offset=offset * src.itemsize,
                strides=strides * src.itemsize,
            )

        assert_equal(make_array(4, 4, -1), np.array([4, 3, 2, 1]))
        assert_raises(ValueError, make_array, 4, 4, -2)
        assert_raises(ValueError, make_array, 4, 2, -1)
        assert_raises(ValueError, make_array, 8, 3, 1)
        assert_equal(make_array(8, 3, 0), np.array([3] * 8))
        # Check behavior reported in gh-2503:
        assert_raises(ValueError, make_array, (2, 3), 5, np.array([-2, -3]))
        make_array(0, 0, 10)

    def test_set_stridesattr(self):
        src = self.one

        def make_array(size, offset, strides):
            # Construction errors become RuntimeError so they can be told
            # apart from the ValueError raised by the strides assignment.
            try:
                view = np.ndarray(
                    [size], dtype=int, buffer=src, offset=offset * src.itemsize
                )
            except Exception as e:
                raise RuntimeError(e)  # noqa: B904
            view.strides = strides * src.itemsize
            return view

        def set_strides(arr, strides):
            arr.strides = strides

        assert_equal(make_array(4, 4, -1), np.array([4, 3, 2, 1]))
        assert_equal(make_array(7, 3, 1), np.array([3, 4, 5, 6, 7, 8, 9]))
        assert_raises(ValueError, make_array, 4, 4, -2)
        assert_raises(ValueError, make_array, 4, 2, -1)
        assert_raises(RuntimeError, make_array, 8, 3, 1)

        # The true extent of the array must be used; this relies on the
        # as_strided base not exposing a buffer.
        bcast = np.lib.stride_tricks.as_strided(np.arange(1), (10, 10), (0, 0))
        assert_raises(
            ValueError, set_strides, bcast, (10 * bcast.itemsize, bcast.itemsize)
        )

        # Offset calculations.
        rev = np.lib.stride_tricks.as_strided(
            np.arange(10, dtype=np.int8)[-1], shape=(10,), strides=(-1,)
        )
        assert_raises(ValueError, set_strides, rev[::-1], -1)
        fwd = rev[::-1]
        fwd.strides = 1
        fwd[::2].strides = 2

        # A 0-d array accepts an empty strides tuple, but not None.
        arr_0d = np.array(0)
        arr_0d.strides = ()
        assert_raises(TypeError, set_strides, arr_0d, None)

    def test_fill(self):
        for code in "?bhilqpBHILQPfdgFDGO":
            filled = np.empty((3, 2, 1), code)
            assigned = np.empty((3, 2, 1), code)
            filled.fill(1)
            assigned[...] = 1
            assert_equal(filled, assigned)

    def test_fill_max_uint64(self):
        filled = np.empty((3, 2, 1), dtype=np.uint64)
        assigned = np.empty((3, 2, 1), dtype=np.uint64)
        biggest = 2**64 - 1
        assigned[...] = biggest
        filled.fill(biggest)
        assert_array_equal(filled, assigned)

    def test_fill_struct_array(self):
        # Fill from a scalar element of the same array.
        pairs = np.array([(0, 0.0), (1, 1.0)], dtype="i4,f8")
        pairs.fill(pairs[0])
        assert_equal(pairs["f1"][1], pairs["f1"][0])
        # Fill from a tuple convertible to a scalar of the structured dtype.
        rec = np.zeros(2, dtype=[("a", "f8"), ("b", "i4")])
        rec.fill((3.5, -2))
        assert_array_equal(rec["a"], [3.5, 3.5])
        assert_array_equal(rec["b"], [-2, -2])

    def test_fill_readonly(self):
        # gh-22922
        frozen = np.zeros(11)
        frozen.setflags(write=False)
        with pytest.raises(ValueError, match=".*read-only"):
            frozen.fill(0)
477
478
@instantiate_parametrized_tests
class TestArrayConstruction(TestCase):
    """Construction of arrays through np.array and the as*array helpers."""

    def test_array(self):
        ones6 = np.ones(6)
        assert_equal(np.array([ones6, ones6]), np.ones((2, 6)))

        expected = np.ones((2, 6))
        assert_equal(np.array([ones6, ones6]), expected)
        expected[1] = 2
        assert_equal(np.array([ones6, ones6 + 1]), expected)

        assert_equal(np.array([[ones6, ones6]]), np.ones((1, 2, 6)))
        assert_equal(
            np.array([[ones6, ones6], [ones6, ones6]]), np.ones((2, 2, 6))
        )

        ones66 = np.ones((6, 6))
        assert_equal(np.array([ones66, ones66]), np.ones((2, 6, 6)))

        mask = np.ones((2, 3), dtype=bool)
        mask[0, 2] = False
        mask[1, 0:2] = False
        assert_equal(np.array([[True, True, False], [False, False, True]]), mask)
        assert_equal(
            np.array([[True, False], [True, False], [False, True]]), mask.T
        )

    @skip(reason="object arrays")
    def test_array_object(self):
        base = np.ones((6,))
        ragged = np.array([[base, base + 1], base + 2], dtype=object)
        assert_equal(len(ragged), 2)
        assert_equal(ragged[0], [base, base + 1])
        assert_equal(ragged[1], base + 2)

    def test_array_empty(self):
        assert_raises(TypeError, np.array)

    def test_0d_array_shape(self):
        assert np.ones(np.array(3)).shape == (3,)

    def test_array_copy_false(self):
        # copy=False shares memory, so writes to the source show up.
        src = np.array([1, 2, 3])
        alias = np.array(src, copy=False)
        src[1] = 3
        assert_array_equal(alias, [1, 3, 3])

    @xpassIfTorchDynamo  # (reason="order='F'")
    def test_array_copy_false_2(self):
        src = np.array([1, 2, 3])
        alias = np.array(src, copy=False, order="F")
        src[1] = 4
        assert_array_equal(alias, [1, 4, 3])
        alias[2] = 7
        assert_array_equal(src, [1, 4, 7])

    def test_array_copy_true(self):
        # copy=True yields independent storage in both directions.
        src = np.array([[1, 2, 3], [1, 2, 3]])
        dup = np.array(src, copy=True)
        src[0, 1] = 3
        dup[0, 2] = -7
        assert_array_equal(dup, [[1, 2, -7], [1, 2, 3]])
        assert_array_equal(src, [[1, 3, 3], [1, 2, 3]])

    @xfail  # (reason="order='F'")
    def test_array_copy_true_2(self):
        src = np.array([[1, 2, 3], [1, 2, 3]])
        dup = np.array(src, copy=True, order="F")
        src[0, 1] = 5
        dup[0, 2] = 7
        assert_array_equal(dup, [[1, 3, 7], [1, 2, 3]])
        assert_array_equal(src, [[1, 5, 3], [1, 2, 3]])

    @xfailIfTorchDynamo
    def test_array_cont(self):
        strided = np.ones(10)[::2]
        as_c = np.ascontiguousarray(strided)
        assert_(as_c.flags.c_contiguous)
        assert_(as_c.flags.f_contiguous)
        assert_(np.asfortranarray(strided).flags.c_contiguous)
        # NOTE: asfortranarray(...).flags.f_contiguous is deliberately not
        # checked here, in 1-d or 2-d (XXX: f ordering).
        strided_2d = np.ones((10, 10))[::2, ::2]
        assert_(np.ascontiguousarray(strided_2d).flags.c_contiguous)

    @parametrize(
        "func",
        [
            subtest(np.array, name="array"),
            subtest(np.asarray, name="asarray"),
            subtest(np.asanyarray, name="asanyarray"),
            subtest(np.ascontiguousarray, name="ascontiguousarray"),
            subtest(np.asfortranarray, name="asfortranarray"),
        ],
    )
    def test_bad_arguments_error(self, func):
        bad_calls = (
            lambda: func(3, dtype="bad dtype"),
            lambda: func(),  # missing arguments
            lambda: func(1, 2, 3, 4, 5, 6, 7, 8),  # too many arguments
        )
        for call in bad_calls:
            with pytest.raises(TypeError):
                call()

    @skip(reason="np.array w/keyword argument")
    @parametrize(
        "func",
        [
            subtest(np.array, name="array"),
            subtest(np.asarray, name="asarray"),
            subtest(np.asanyarray, name="asanyarray"),
            subtest(np.ascontiguousarray, name="ascontiguousarray"),
            subtest(np.asfortranarray, name="asfortranarray"),
        ],
    )
    def test_array_as_keyword(self, func):
        # This should likely be made positional only, but do not change
        # the name accidentally.
        keyword = "object" if func is np.array else "a"
        func(**{keyword: 3})
607
608
class TestAssignment(TestCase):
    """Element and slice assignment semantics."""

    def test_assignment_broadcasting(self):
        target = np.arange(6).reshape(2, 3)

        # Broadcasting the input to the output
        target[...] = np.arange(3)
        assert_equal(target, [[0, 1, 2], [0, 1, 2]])
        target[...] = np.arange(2).reshape(2, 1)
        assert_equal(target, [[0, 0, 0], [1, 1, 1]])

        # For compatibility with <= 1.5, a limited version of broadcasting
        # the output to the input.
        #
        # This behavior is inconsistent with NumPy broadcasting
        # in general, because it only uses one of the two broadcasting
        # rules (adding a new "1" dimension to the left of the shape),
        # applied to the output instead of an input. In NumPy 2.0, this kind
        # of broadcasting assignment will likely be disallowed.
        target[...] = np.flip(np.arange(6)).reshape(1, 2, 3)
        assert_equal(target, [[5, 4, 3], [2, 1, 0]])

        # The other type of broadcasting would require a reduction operation,
        # so it must be rejected.
        assert_raises(
            (RuntimeError, ValueError),
            operator.setitem,
            target,
            Ellipsis,
            np.arange(12).reshape(2, 2, 3),
        )

    def test_assignment_errors(self):
        # Address issue #2276
        class C:
            pass

        zeros = np.zeros(1)
        assert_raises((RuntimeError, TypeError), operator.setitem, zeros, 0, C())
        # NumPy also raises (TypeError/ValueError) for ``zeros[0] = [1]``; the
        # implementation under test does not, so that check stays disabled.

    @skip(reason="object arrays")
    def test_unicode_assignment(self):
        # gh-5049
        from numpy.core.numeric import set_string_function

        @contextmanager
        def inject_str(s):
            """Temporarily make ndarray.__str__ return the constant *s*."""
            set_string_function(lambda x: s, repr=False)
            try:
                yield
            finally:
                set_string_function(None, repr=False)

        one_d = np.array(["test"])
        zero_d = np.array("done")
        with inject_str("bad"):
            one_d[0] = zero_d  # previously this would invoke __str__
        assert_equal(one_d[0], "done")

        # this would crash for the same reason
        np.array([np.array("\xe5\xe4\xf6")])

    @skip(reason="object arrays")
    def test_stringlike_empty_list(self):
        # gh-8902
        unicode_arr = np.array(["done"])
        bytes_arr = np.array([b"done"])

        class bad_sequence:
            def __getitem__(self, value):
                pass

            def __len__(self):
                raise RuntimeError

        # A fresh value is built for every single assignment attempt.
        for make_value in (list, bad_sequence):
            for arr in (unicode_arr, bytes_arr):
                assert_raises(ValueError, operator.setitem, arr, 0, make_value())

    @skipif(
        "torch._numpy" == np.__name__,
        reason="torch._numpy does not support extended floats and complex dtypes",
    )
    def test_longdouble_assignment(self):
        # Only relevant if longdouble is wider than float64: we are looking
        # for loss of precision during assignment.
        for dtype in (np.longdouble, np.clongdouble):
            # gh-8902
            tiny_pos = np.nextafter(np.longdouble(0), 1).astype(dtype)
            tiny_neg = np.nextafter(np.longdouble(0), -1).astype(dtype)

            # construction
            holder = np.array([tiny_neg])
            assert_equal(holder[0], tiny_neg)

            # scalar = scalar
            holder[0] = tiny_pos
            assert_equal(holder[0], tiny_pos)

            # 0d = scalar
            holder[0, ...] = tiny_neg
            assert_equal(holder[0], tiny_neg)

            # 0d = 0d
            holder[0, ...] = tiny_pos[...]
            assert_equal(holder[0], tiny_pos)

            # scalar = 0d
            holder[0] = tiny_pos[...]
            assert_equal(holder[0], tiny_pos)

            wrapped = np.array([np.array(tiny_neg)])
            assert_equal(wrapped[0], tiny_neg)

    @skip(reason="object arrays")
    def test_cast_to_string(self):
        # Casting to str must use "str(scalar)", not "str(scalar.item())".
        # Example: in Python 2, str(float) is truncated, so going through
        # str(np.float64(...).item()) would incorrectly truncate.
        text = np.zeros(1, dtype="S20")
        text[:] = np.array(["1.12345678901234567890"], dtype="f8")
        assert_equal(text[0], b"1.1234567890123457")
737
738
class TestDtypedescr(TestCase):
    def test_construction(self):
        """Type-code strings build the same dtypes as the scalar types."""
        for code, scalar_type in (("i4", np.int32), ("f8", np.float64)):
            assert_equal(np.dtype(code), np.dtype(scalar_type))
745
746
# ``skip`` is ``functools.partial(skipif, True)`` and therefore has to be
# called with a reason. The previous bare ``@skip`` evaluated
# ``skipIf(True, TestZeroRank)``, which returns a decorator function, so the
# name ``TestZeroRank`` ended up bound to that function and the tests silently
# vanished instead of being reported as skipped.
@skip(reason="TODO: zero-rank?")  # FIXME: revert skip into xfail
class TestZeroRank(TestCase):
    """Indexing and assignment on 0-d arrays (a numeric and an object one)."""

    def setUp(self):
        self.d = np.array(0), np.array("x", object)

    def test_ellipsis_subscript(self):
        a, b = self.d
        assert_equal(a[...], 0)
        assert_equal(b[...], "x")
        assert_(a[...].base is a)  # `a[...] is a` in numpy <1.9.
        assert_(b[...].base is b)  # `b[...] is b` in numpy <1.9.

    def test_empty_subscript(self):
        a, b = self.d
        assert_equal(a[()], 0)
        assert_equal(b[()], "x")
        assert_(type(a[()]) is a.dtype.type)
        assert_(type(b[()]) is str)

    def test_invalid_subscript(self):
        a, b = self.d
        assert_raises(IndexError, lambda x: x[0], a)
        assert_raises(IndexError, lambda x: x[0], b)
        assert_raises(IndexError, lambda x: x[np.array([], int)], a)
        assert_raises(IndexError, lambda x: x[np.array([], int)], b)

    def test_ellipsis_subscript_assignment(self):
        a, b = self.d
        a[...] = 42
        assert_equal(a, 42)
        b[...] = ""
        assert_equal(b.item(), "")

    def test_empty_subscript_assignment(self):
        a, b = self.d
        a[()] = 42
        assert_equal(a, 42)
        b[()] = ""
        assert_equal(b.item(), "")

    def test_invalid_subscript_assignment(self):
        a, b = self.d

        def assign(x, i, v):
            x[i] = v

        assert_raises(IndexError, assign, a, 0, 42)
        assert_raises(IndexError, assign, b, 0, "")
        assert_raises(ValueError, assign, a, (), "")

    def test_newaxis(self):
        a, b = self.d
        assert_equal(a[np.newaxis].shape, (1,))
        assert_equal(a[..., np.newaxis].shape, (1,))
        assert_equal(a[np.newaxis, ...].shape, (1,))
        assert_equal(a[..., np.newaxis].shape, (1,))
        assert_equal(a[np.newaxis, ..., np.newaxis].shape, (1, 1))
        assert_equal(a[..., np.newaxis, np.newaxis].shape, (1, 1))
        assert_equal(a[np.newaxis, np.newaxis, ...].shape, (1, 1))
        assert_equal(a[(np.newaxis,) * 10].shape, (1,) * 10)

    def test_invalid_newaxis(self):
        a, b = self.d

        def subscript(x, i):
            x[i]

        assert_raises(IndexError, subscript, a, (np.newaxis, 0))
        assert_raises(IndexError, subscript, a, (np.newaxis,) * 50)

    def test_constructor(self):
        x = np.ndarray(())
        x[()] = 5
        assert_equal(x[()], 5)
        # y shares x's buffer, so writing through y is visible in x.
        y = np.ndarray((), buffer=x)
        y[()] = 6
        assert_equal(x[()], 6)

        # strides and shape must be the same length
        with pytest.raises(ValueError):
            np.ndarray((2,), strides=())
        with pytest.raises(ValueError):
            np.ndarray((), strides=(2,))

    def test_output(self):
        x = np.array(2)
        assert_raises(ValueError, np.add, x, [1], x)

    def test_real_imag(self):
        # contiguity checks are for gh-11245
        x = np.array(1j)
        xr = x.real
        xi = x.imag

        assert_equal(xr, np.array(0))
        assert_(type(xr) is np.ndarray)
        assert_equal(xr.flags.contiguous, True)
        assert_equal(xr.flags.f_contiguous, True)

        assert_equal(xi, np.array(1))
        assert_(type(xi) is np.ndarray)
        assert_equal(xi.flags.contiguous, True)
        assert_equal(xi.flags.f_contiguous, True)
850
851
class TestScalarIndexing(TestCase):
    """Indexing behavior of a scalar taken out of a 1-d array."""

    def setUp(self):
        # Integer-indexing a 1-d array yields a 0-d value equal to 0.
        self.d = np.array([0, 1])[0]

    def test_ellipsis_subscript(self):
        a = self.d
        assert_equal(a[...], 0)
        assert_equal(a[...].shape, ())

    def test_empty_subscript(self):
        a = self.d
        assert_equal(a[()], 0)
        assert_equal(a[()].shape, ())

    def test_invalid_subscript(self):
        # A 0-d value has no axis, so an integer or an index array is invalid.
        a = self.d
        assert_raises(IndexError, lambda x: x[0], a)
        assert_raises(IndexError, lambda x: x[np.array([], int)], a)

    def test_invalid_subscript_assignment(self):
        a = self.d

        def assign(x, i, v):
            x[i] = v

        # Either error type is accepted for item assignment into the scalar.
        assert_raises((IndexError, TypeError), assign, a, 0, 42)

    def test_newaxis(self):
        # Each newaxis adds a length-1 dimension; Ellipsis adds none.
        a = self.d
        assert_equal(a[np.newaxis].shape, (1,))
        assert_equal(a[..., np.newaxis].shape, (1,))
        assert_equal(a[np.newaxis, ...].shape, (1,))
        assert_equal(a[np.newaxis, ..., np.newaxis].shape, (1, 1))
        assert_equal(a[..., np.newaxis, np.newaxis].shape, (1, 1))
        assert_equal(a[np.newaxis, np.newaxis, ...].shape, (1, 1))
        assert_equal(a[(np.newaxis,) * 10].shape, (1,) * 10)

    def test_invalid_newaxis(self):
        a = self.d

        def subscript(x, i):
            x[i]

        # After adding an axis, the integer 0 still has no axis to index.
        assert_raises(IndexError, subscript, a, (np.newaxis, 0))

        # this assertion fails because 50 > NPY_MAXDIMS = 32
        # assert_raises(IndexError, subscript, a, (np.newaxis,)*50)

    @xfail  # (reason="pytorch disallows overlapping assignments")
    def test_overlapping_assignment(self):
        # Expected values are what NumPy produces when source and destination
        # views share memory.
        # With positive strides
        a = np.arange(4)
        a[:-1] = a[1:]
        assert_equal(a, [1, 2, 3, 3])

        a = np.arange(4)
        a[1:] = a[:-1]
        assert_equal(a, [0, 0, 1, 2])

        # With positive and negative strides
        a = np.arange(4)
        a[:] = a[::-1]
        assert_equal(a, [3, 2, 1, 0])

        a = np.arange(6).reshape(2, 3)
        a[::-1, :] = a[:, ::-1]
        assert_equal(a, [[5, 4, 3], [2, 1, 0]])

        a = np.arange(6).reshape(2, 3)
        a[::-1, ::-1] = a[:, ::-1]
        assert_equal(a, [[3, 4, 5], [0, 1, 2]])

        # With just one element overlapping
        a = np.arange(5)
        a[:3] = a[2:]
        assert_equal(a, [2, 3, 4, 3, 4])

        a = np.arange(5)
        a[2:] = a[:3]
        assert_equal(a, [0, 1, 0, 1, 2])

        a = np.arange(5)
        a[2::-1] = a[2:]
        assert_equal(a, [4, 3, 2, 3, 4])

        a = np.arange(5)
        a[2:] = a[2::-1]
        assert_equal(a, [0, 1, 2, 1, 0])

        a = np.arange(5)
        a[2::-1] = a[:1:-1]
        assert_equal(a, [2, 3, 4, 3, 4])

        a = np.arange(5)
        a[:1:-1] = a[2::-1]
        assert_equal(a, [0, 1, 0, 1, 2])
950
@skip(reason="object, void, structured dtypes")
@instantiate_parametrized_tests
class TestCreation(TestCase):
    """
    Test the np.array constructor
    """

    def test_from_attribute(self):
        # An ``__array__`` that returns None instead of an array is rejected.
        class x:
            def __array__(self, dtype=None):
                pass

        assert_raises(ValueError, np.array, x())

    def test_from_string(self):
        # Numeric strings convert to every integer and float typecode.
        types = np.typecodes["AllInteger"] + np.typecodes["Float"]
        nstr = ["123", "123"]
        result = np.array([123, 123], dtype=int)
        for typecode in types:
            msg = f"String conversion for {typecode}"
            assert_equal(np.array(nstr, dtype=typecode), result, err_msg=msg)

    def test_void(self):
        arr = np.array([], dtype="V")
        assert arr.dtype == "V8"  # current default
        # Same length scalars (those that go to the same void) work:
        arr = np.array([b"1234", b"1234"], dtype="V")
        assert arr.dtype == "V4"

        # Promoting different lengths will fail (pre 1.20 this worked)
        # by going via S5 and casting to V5.
        with pytest.raises(TypeError):
            np.array([b"1234", b"12345"], dtype="V")
        with pytest.raises(TypeError):
            np.array([b"12345", b"1234"], dtype="V")

        # Check the same for the casting path:
        arr = np.array([b"1234", b"1234"], dtype="O").astype("V")
        assert arr.dtype == "V4"
        with pytest.raises(TypeError):
            np.array([b"1234", b"12345"], dtype="O").astype("V")

    @parametrize(
        #  "idx", [pytest.param(Ellipsis, id="arr"), pytest.param((), id="scalar")]
        "idx",
        [subtest(Ellipsis, name="arr"), subtest((), name="scalar")],
    )
    def test_structured_void_promotion(self, idx):
        arr = np.array(
            [np.array(1, dtype="i,i")[idx], np.array(2, dtype="i,i")[idx]], dtype="V"
        )
        assert_array_equal(arr, np.array([(1, 1), (2, 2)], dtype="i,i"))
        # The following fails to promote the two dtypes, resulting in an error
        with pytest.raises(TypeError):
            np.array(
                [np.array(1, dtype="i,i")[idx], np.array(2, dtype="i,i,i")[idx]],
                dtype="V",
            )

    def test_too_big_error(self):
        # 46341 is the smallest integer greater than sqrt(2**31 - 1).
        # 3037000500 is the smallest integer greater than sqrt(2**63 - 1).
        # We want to make sure that the square byte array with those dimensions
        # is too big on 32 or 64 bit systems respectively.
        if np.iinfo("intp").max == 2**31 - 1:
            shape = (46341, 46341)
        elif np.iinfo("intp").max == 2**63 - 1:
            shape = (3037000500, 3037000500)
        else:
            return
        assert_raises(ValueError, np.empty, shape, dtype=np.int8)
        assert_raises(ValueError, np.zeros, shape, dtype=np.int8)
        assert_raises(ValueError, np.ones, shape, dtype=np.int8)

    @skipif(
        np.dtype(np.intp).itemsize != 8, reason="malloc may not fail on 32 bit systems"
    )
    def test_malloc_fails(self):
        # This test is guaranteed to fail due to a too large allocation
        with assert_raises(np.core._exceptions._ArrayMemoryError):
            np.empty(np.iinfo(np.intp).max, dtype=np.uint8)

    def test_zeros(self):
        types = np.typecodes["AllInteger"] + np.typecodes["AllFloat"]
        for dt in types:
            d = np.zeros((13,), dtype=dt)
            assert_equal(np.count_nonzero(d), 0)
            # true for ieee floats
            assert_equal(d.sum(), 0)
            assert_(not d.any())

        # The subarray and structured dtypes below do not depend on the loop
        # variable above, so they are checked once instead of once per type.
        d = np.zeros(2, dtype="(2,4)i4")
        assert_equal(np.count_nonzero(d), 0)
        assert_equal(d.sum(), 0)
        assert_(not d.any())

        d = np.zeros(2, dtype="4i4")
        assert_equal(np.count_nonzero(d), 0)
        assert_equal(d.sum(), 0)
        assert_(not d.any())

        d = np.zeros(2, dtype="(2,4)i4, (2,4)i4")
        assert_equal(np.count_nonzero(d), 0)

    @slow
    def test_zeros_big(self):
        # test big array as they might be allocated different by the system
        types = np.typecodes["AllInteger"] + np.typecodes["AllFloat"]
        for dt in types:
            d = np.zeros((30 * 1024**2,), dtype=dt)
            assert_(not d.any())
            # This test can fail on 32-bit systems due to insufficient
            # contiguous memory. Deallocating the previous array increases the
            # chance of success.
            del d

    def test_zeros_obj(self):
        # test initialization from PyLong(0)
        d = np.zeros((13,), dtype=object)
        assert_array_equal(d, [0] * 13)
        assert_equal(np.count_nonzero(d), 0)

    def test_zeros_obj_obj(self):
        d = np.zeros(10, dtype=[("k", object, 2)])
        assert_array_equal(d["k"], 0)

    def test_zeros_like_like_zeros(self):
        # test zeros_like returns the same as zeros
        for c in np.typecodes["All"]:
            if c == "V":
                continue
            d = np.zeros((3, 3), dtype=c)
            assert_array_equal(np.zeros_like(d), d)
            assert_equal(np.zeros_like(d).dtype, d.dtype)
        # explicitly check some special cases
        d = np.zeros((3, 3), dtype="S5")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)
        d = np.zeros((3, 3), dtype="U5")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)

        d = np.zeros((3, 3), dtype="<i4")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)
        d = np.zeros((3, 3), dtype=">i4")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)

        d = np.zeros((3, 3), dtype="<M8[s]")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)
        d = np.zeros((3, 3), dtype=">M8[s]")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)

        d = np.zeros((3, 3), dtype="f4,f4")
        assert_array_equal(np.zeros_like(d), d)
        assert_equal(np.zeros_like(d).dtype, d.dtype)

    def test_empty_unicode(self):
        # don't throw decode errors on garbage memory
        for i in range(5, 100, 5):
            d = np.empty(i, dtype="U")
            str(d)

    def test_sequence_non_homogeneous(self):
        assert_equal(np.array([4, 2**80]).dtype, object)
        assert_equal(np.array([4, 2**80, 4]).dtype, object)
        assert_equal(np.array([2**80, 4]).dtype, object)
        assert_equal(np.array([2**80] * 3).dtype, object)
        assert_equal(np.array([[1, 1], [1j, 1j]]).dtype, complex)
        assert_equal(np.array([[1j, 1j], [1, 1]]).dtype, complex)
        assert_equal(np.array([[1, 1, 1], [1, 1j, 1.0], [1, 1, 1]]).dtype, complex)

    def test_non_sequence_sequence(self):
        """Should not segfault.

        Class Fail breaks the sequence protocol for new style classes, i.e.,
        those derived from object. Class Map is a mapping type indicated by
        raising a ValueError. At some point we may raise a warning instead
        of an error in the Fail case.

        """

        class Fail:
            def __len__(self):
                return 1

            def __getitem__(self, index):
                raise ValueError

        class Map:
            def __len__(self):
                return 1

            def __getitem__(self, index):
                raise KeyError

        a = np.array([Map()])
        assert_(a.shape == (1,))
        assert_(a.dtype == np.dtype(object))
        assert_raises(ValueError, np.array, [Fail()])

    def test_no_len_object_type(self):
        # gh-5100, want object array from iterable object without len()
        class Point2:
            def __init__(self) -> None:
                pass

            def __getitem__(self, ind):
                if ind in [0, 1]:
                    return ind
                else:
                    raise IndexError

        d = np.array([Point2(), Point2(), Point2()])
        assert_equal(d.dtype, np.dtype(object))

    def test_false_len_sequence(self):
        # gh-7264, segfault for this example
        class C:
            def __getitem__(self, i):
                raise IndexError

            def __len__(self):
                return 42

        a = np.array(C())  # segfault?
        assert_equal(len(a), 0)

    def test_false_len_iterable(self):
        # Special case where a bad __getitem__ makes us fall back on __iter__:
        class C:
            def __getitem__(self, x):
                raise Exception  # noqa: TRY002

            def __iter__(self):
                return iter(())

            def __len__(self):
                return 2

        a = np.empty(2)
        with assert_raises(ValueError):
            a[:] = C()  # Segfault!

        assert_equal(np.array(C()), list(C()))

    def test_failed_len_sequence(self):
        # gh-7393
        class A:
            def __init__(self, data):
                self._data = data

            def __getitem__(self, item):
                return type(self)(self._data[item])

            def __len__(self):
                return len(self._data)

        # len(d) should give 3, but len(d[0]) will fail
        d = A([1, 2, 3])
        assert_equal(len(np.array(d)), 3)

    def test_array_too_big(self):
        # Test that array creation succeeds for arrays addressable by intp
        # on the byte level and fails for too large arrays.
        buf = np.zeros(100)

        max_bytes = np.iinfo(np.intp).max
        for dtype in ["intp", "S20", "b"]:
            dtype = np.dtype(dtype)
            itemsize = dtype.itemsize

            # strides=(0,) lets a tiny buffer back an arbitrarily long array.
            np.ndarray(
                buffer=buf, strides=(0,), shape=(max_bytes // itemsize,), dtype=dtype
            )
            assert_raises(
                ValueError,
                np.ndarray,
                buffer=buf,
                strides=(0,),
                shape=(max_bytes // itemsize + 1,),
                dtype=dtype,
            )

    def _ragged_creation(self, seq):
        """Check that ``seq`` is rejected as ragged, then build it as objects."""
        # without dtype=object, the ragged object raises
        with pytest.raises(ValueError, match=".*detected shape was"):
            np.array(seq)

        return np.array(seq, dtype=object)

    def test_ragged_ndim_object(self):
        # Lists of mismatching depths are treated as object arrays
        a = self._ragged_creation([[1], 2, 3])
        assert_equal(a.shape, (3,))
        assert_equal(a.dtype, object)

        a = self._ragged_creation([1, [2], 3])
        assert_equal(a.shape, (3,))
        assert_equal(a.dtype, object)

        a = self._ragged_creation([1, 2, [3]])
        assert_equal(a.shape, (3,))
        assert_equal(a.dtype, object)

    def test_ragged_shape_object(self):
        # The ragged dimension of a list is turned into an object array
        a = self._ragged_creation([[1, 1], [2], [3]])
        assert_equal(a.shape, (3,))
        assert_equal(a.dtype, object)

        a = self._ragged_creation([[1], [2, 2], [3]])
        assert_equal(a.shape, (3,))
        assert_equal(a.dtype, object)

        a = self._ragged_creation([[1], [2], [3, 3]])
        assert a.shape == (3,)
        assert a.dtype == object

    def test_array_of_ragged_array(self):
        outer = np.array([None, None])
        outer[0] = outer[1] = np.array([1, 2, 3])
        assert np.array(outer).shape == (2,)
        assert np.array([outer]).shape == (1, 2)

        outer_ragged = np.array([None, None])
        outer_ragged[0] = np.array([1, 2, 3])
        outer_ragged[1] = np.array([1, 2, 3, 4])
        # should both of these emit deprecation warnings?
        assert np.array(outer_ragged).shape == (2,)
        assert np.array([outer_ragged]).shape == (1, 2)

    def test_deep_nonragged_object(self):
        # None of these should raise, even though they are missing dtype=object
        np.array([[[Decimal(1)]]])
        np.array([1, Decimal(1)])
        np.array([[1], [Decimal(1)]])

    @parametrize("dtype", [object, "O,O", "O,(3)O", "(2,3)O"])
    @parametrize(
        "function",
        [
            np.ndarray,
            np.empty,
            lambda shape, dtype: np.empty_like(np.empty(shape, dtype=dtype)),
        ],
    )
    def test_object_initialized_to_None(self, function, dtype):
        # NumPy has support for object fields to be NULL (meaning None)
        # but generally, we should always fill with the proper None, and
        # downstream may rely on that.  (For fully initialized arrays!)
        arr = function(3, dtype=dtype)
        # We expect a fill value of None, which is not NULL:
        expected = np.array(None).tobytes()
        expected = expected * (arr.nbytes // len(expected))
        assert arr.tobytes() == expected
1313
1314
class TestBool(TestCase):
    """Tests for boolean scalars and boolean arrays."""

    @xfail  # (reason="bools not interned")
    def test_test_interning(self):
        # Every route to a boolean scalar should hand back one shared object
        # per truth value.
        false_from_int = np.bool_(0)
        false_from_bool = np.bool_(False)
        assert_(false_from_int is false_from_bool)
        true_from_int = np.bool_(1)
        true_from_bool = np.bool_(True)
        assert_(true_from_int is true_from_bool)
        assert_(np.array([True])[0] is true_from_int)
        assert_(np.array(True)[()] is true_from_int)

    def test_sum(self):
        # Summing an all-True array counts its elements.
        ones = np.ones(101, dtype=bool)
        assert_equal(ones.sum(), ones.size)
        strided = ones[::2]
        assert_equal(strided.sum(), strided.size)
        # Reversed-stride slices (ones[::-2]) are not checked in this test.

    @xpassIfTorchDynamo  # (reason="frombuffer")
    def test_sum_2(self):
        # Bytes 0xff are all truthy, so every view sums to its own size.
        d = np.frombuffer(b"\xff\xff" * 100, dtype=bool)
        for view in (d, d[::2], d[::-2]):
            assert_equal(view.sum(), view.size)

    def check_count_nonzero(self, power, length):
        """Check ``count_nonzero`` for every bit pattern below ``2**power``.

        Each pattern is expanded into a bool array of ``length`` elements.
        The array's bytes are then rewritten to other nonzero values through a
        uint8 view, and the count must stay the same each time.
        """
        masks = [1 << bit for bit in range(length)]
        for pattern in range(2**power):
            bits = [(pattern & mask) != 0 for mask in masks]
            arr = np.array(bits, dtype=bool)
            expected = builtins.sum(bits)
            assert_equal(np.count_nonzero(arr), expected)
            raw = arr.view(np.uint8)
            for factor in (3, 4):
                raw *= factor
                assert_equal(np.count_nonzero(arr), expected)
            raw[raw != 0] = 0xFF
            assert_equal(np.count_nonzero(arr), expected)

    def test_count_nonzero(self):
        # Every 12-bit pattern in a length-17 array; this covers most cases
        # of the 16 byte unrolled code.
        self.check_count_nonzero(power=12, length=17)

    @slow
    def test_count_nonzero_all(self):
        # Every pattern of a length-17 array; this covers all cases of the
        # 16 byte unrolled code.
        self.check_count_nonzero(power=17, length=17)

    def test_count_nonzero_unaligned(self):
        # prevent mistakes as e.g. gh-4060
        for offset in range(7):
            for make, fill_value in ((np.zeros, True), (np.ones, False)):
                a = make((18,), dtype=bool)[offset + 1 :]
                a[:offset] = fill_value
                assert_equal(np.count_nonzero(a), builtins.sum(a.tolist()))

    def _test_cast_from_flexible(self, dtype):
        """Empty flexible values cast to False; any non-empty value to True."""

        def check(raw, width, expected):
            v = np.array(raw, (dtype, width))
            assert_equal(bool(v), expected)
            assert_equal(bool(v[()]), expected)
            assert_equal(v.astype(bool), expected)
            assert_(isinstance(v.astype(bool), np.ndarray))
            singleton = np.True_ if expected else np.False_
            assert_(v[()].astype(bool) is singleton)

        # empty string -> false
        for width in range(3):
            check(b"", width, False)

        # anything else -> true
        for width in range(1, 4):
            for raw in [b"a", b"0", b" "]:
                check(raw, width, True)

    @skip(reason="np.void")
    def test_cast_from_void(self):
        self._test_cast_from_flexible(np.void)

    @xfail  # (reason="See gh-9847")
    def test_cast_from_unicode(self):
        self._test_cast_from_flexible(np.str_)

    @xfail  # (reason="See gh-9847")
    def test_cast_from_bytes(self):
        self._test_cast_from_flexible(np.bytes_)
1407
1408
1409@instantiate_parametrized_tests
1410class TestMethods(TestCase):
1411    sort_kinds = ["quicksort", "heapsort", "stable"]
1412
1413    @xpassIfTorchDynamo  # (reason="all(..., where=...)")
1414    def test_all_where(self):
1415        a = np.array([[True, False, True], [False, False, False], [True, True, True]])
1416        wh_full = np.array(
1417            [[True, False, True], [False, False, False], [True, False, True]]
1418        )
1419        wh_lower = np.array([[False], [False], [True]])
1420        for _ax in [0, None]:
1421            assert_equal(
1422                a.all(axis=_ax, where=wh_lower), np.all(a[wh_lower[:, 0], :], axis=_ax)
1423            )
1424            assert_equal(
1425                np.all(a, axis=_ax, where=wh_lower), a[wh_lower[:, 0], :].all(axis=_ax)
1426            )
1427
1428        assert_equal(a.all(where=wh_full), True)
1429        assert_equal(np.all(a, where=wh_full), True)
1430        assert_equal(a.all(where=False), True)
1431        assert_equal(np.all(a, where=False), True)
1432
1433    @xpassIfTorchDynamo  # (reason="any(..., where=...)")
1434    def test_any_where(self):
1435        a = np.array([[True, False, True], [False, False, False], [True, True, True]])
1436        wh_full = np.array(
1437            [[False, True, False], [True, True, True], [False, False, False]]
1438        )
1439        wh_middle = np.array([[False], [True], [False]])
1440        for _ax in [0, None]:
1441            assert_equal(
1442                a.any(axis=_ax, where=wh_middle),
1443                np.any(a[wh_middle[:, 0], :], axis=_ax),
1444            )
1445            assert_equal(
1446                np.any(a, axis=_ax, where=wh_middle),
1447                a[wh_middle[:, 0], :].any(axis=_ax),
1448            )
1449        assert_equal(a.any(where=wh_full), False)
1450        assert_equal(np.any(a, where=wh_full), False)
1451        assert_equal(a.any(where=False), False)
1452        assert_equal(np.any(a, where=False), False)
1453
1454    @xpassIfTorchDynamo  # (reason="TODO: compress")
1455    def test_compress(self):
1456        tgt = [[5, 6, 7, 8, 9]]
1457        arr = np.arange(10).reshape(2, 5)
1458        out = arr.compress([0, 1], axis=0)
1459        assert_equal(out, tgt)
1460
1461        tgt = [[1, 3], [6, 8]]
1462        out = arr.compress([0, 1, 0, 1, 0], axis=1)
1463        assert_equal(out, tgt)
1464
1465        tgt = [[1], [6]]
1466        arr = np.arange(10).reshape(2, 5)
1467        out = arr.compress([0, 1], axis=1)
1468        assert_equal(out, tgt)
1469
1470        arr = np.arange(10).reshape(2, 5)
1471        out = arr.compress([0, 1])
1472        assert_equal(out, 1)
1473
1474    def test_choose(self):
1475        x = 2 * np.ones((3,), dtype=int)
1476        y = 3 * np.ones((3,), dtype=int)
1477        x2 = 2 * np.ones((2, 3), dtype=int)
1478        y2 = 3 * np.ones((2, 3), dtype=int)
1479        ind = np.array([0, 0, 1])
1480
1481        A = ind.choose((x, y))
1482        assert_equal(A, [2, 2, 3])
1483
1484        A = ind.choose((x2, y2))
1485        assert_equal(A, [[2, 2, 3], [2, 2, 3]])
1486
1487        A = ind.choose((x, y2))
1488        assert_equal(A, [[2, 2, 3], [2, 2, 3]])
1489
1490        out = np.array(0)
1491        ret = np.choose(np.array(1), [10, 20, 30], out=out)
1492        assert out is ret
1493        assert_equal(out[()], 20)
1494
1495    @xpassIfTorchDynamo  # (reason="choose(..., mode=...) not implemented")
1496    def test_choose_2(self):
1497        # gh-6272 check overlap on out
1498        x = np.arange(5)
1499        y = np.choose([0, 0, 0], [x[:3], x[:3], x[:3]], out=x[1:4], mode="wrap")
1500        assert_equal(y, np.array([0, 1, 2]))
1501
1502    def test_prod(self):
1503        ba = [1, 2, 10, 11, 6, 5, 4]
1504        ba2 = [[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]]
1505
1506        for ctype in [
1507            np.int16,
1508            np.int32,
1509            np.float32,
1510            np.float64,
1511            np.complex64,
1512            np.complex128,
1513        ]:
1514            a = np.array(ba, ctype)
1515            a2 = np.array(ba2, ctype)
1516            if ctype in ["1", "b"]:
1517                assert_raises(ArithmeticError, a.prod)
1518                assert_raises(ArithmeticError, a2.prod, axis=1)
1519            else:
1520                assert_equal(a.prod(axis=0), 26400)
1521                assert_array_equal(a2.prod(axis=0), np.array([50, 36, 84, 180], ctype))
1522                assert_array_equal(a2.prod(axis=-1), np.array([24, 1890, 600], ctype))
1523
1524    def test_repeat(self):
1525        m = np.array([1, 2, 3, 4, 5, 6])
1526        m_rect = m.reshape((2, 3))
1527
1528        A = m.repeat([1, 3, 2, 1, 1, 2])
1529        assert_equal(A, [1, 2, 2, 2, 3, 3, 4, 5, 6, 6])
1530
1531        A = m.repeat(2)
1532        assert_equal(A, [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])
1533
1534        A = m_rect.repeat([2, 1], axis=0)
1535        assert_equal(A, [[1, 2, 3], [1, 2, 3], [4, 5, 6]])
1536
1537        A = m_rect.repeat([1, 3, 2], axis=1)
1538        assert_equal(A, [[1, 2, 2, 2, 3, 3], [4, 5, 5, 5, 6, 6]])
1539
1540        A = m_rect.repeat(2, axis=0)
1541        assert_equal(A, [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
1542
1543        A = m_rect.repeat(2, axis=1)
1544        assert_equal(A, [[1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6]])
1545
1546    @xpassIfTorchDynamo  # (reason="reshape(..., order='F')")
1547    def test_reshape(self):
1548        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
1549
1550        tgt = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]
1551        assert_equal(arr.reshape(2, 6), tgt)
1552
1553        tgt = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
1554        assert_equal(arr.reshape(3, 4), tgt)
1555
1556        tgt = [[1, 10, 8, 6], [4, 2, 11, 9], [7, 5, 3, 12]]
1557        assert_equal(arr.reshape((3, 4), order="F"), tgt)
1558
1559        tgt = [[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12]]
1560        assert_equal(arr.T.reshape((3, 4), order="C"), tgt)
1561
1562    def test_round(self):
1563        def check_round(arr, expected, *round_args):
1564            assert_equal(arr.round(*round_args), expected)
1565            # With output array
1566            out = np.zeros_like(arr)
1567            res = arr.round(*round_args, out=out)
1568            assert_equal(out, expected)
1569            assert out is res
1570
1571        check_round(np.array([1.2, 1.5]), [1, 2])
1572        check_round(np.array(1.5), 2)
1573        check_round(np.array([12.2, 15.5]), [10, 20], -1)
1574        check_round(np.array([12.15, 15.51]), [12.2, 15.5], 1)
1575        # Complex rounding
1576        check_round(np.array([4.5 + 1.5j]), [4 + 2j])
1577        check_round(np.array([12.5 + 15.5j]), [10 + 20j], -1)
1578
1579    def test_squeeze(self):
1580        a = np.array([[[1], [2], [3]]])
1581        assert_equal(a.squeeze(), [1, 2, 3])
1582        assert_equal(a.squeeze(axis=(0,)), [[1], [2], [3]])
1583        #  assert_raises(ValueError, a.squeeze, axis=(1,))   # a noop in pytorch
1584        assert_equal(a.squeeze(axis=(2,)), [[1, 2, 3]])
1585
1586    def test_transpose(self):
1587        a = np.array([[1, 2], [3, 4]])
1588        assert_equal(a.transpose(), [[1, 3], [2, 4]])
1589        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0))
1590        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 0))
1591        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 1, 2))
1592
1593    def test_sort(self):
1594        # test ordering for floats and complex containing nans. It is only
1595        # necessary to check the less-than comparison, so sorts that
1596        # only follow the insertion sort path are sufficient. We only
1597        # test doubles and complex doubles as the logic is the same.
1598
1599        # check doubles
1600        msg = "Test real sort order with nans"
1601        a = np.array([np.nan, 1, 0])
1602        b = np.sort(a)
1603        assert_equal(b, np.flip(a), msg)
1604
1605    @xpassIfTorchDynamo  # (reason="sort complex")
1606    def test_sort_complex_nans(self):
1607        # check complex
1608        msg = "Test complex sort order with nans"
1609        a = np.zeros(9, dtype=np.complex128)
1610        a.real += [np.nan, np.nan, np.nan, 1, 0, 1, 1, 0, 0]
1611        a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0]
1612        b = np.sort(a)
1613        assert_equal(b, a[::-1], msg)
1614
1615    # all c scalar sorts use the same code with different types
1616    # so it suffices to run a quick check with one type. The number
1617    # of sorted items must be greater than ~50 to check the actual
1618    # algorithm because quick and merge sort fall over to insertion
1619    # sort for small arrays.
1620
1621    @parametrize("dtype", [np.uint8, np.float16, np.float32, np.float64])
1622    def test_sort_unsigned(self, dtype):
1623        a = np.arange(101, dtype=dtype)
1624        b = np.flip(a)
1625        for kind in self.sort_kinds:
1626            msg = f"scalar sort, kind={kind}"
1627            c = a.copy()
1628            c.sort(kind=kind)
1629            assert_equal(c, a, msg)
1630            c = b.copy()
1631            c.sort(kind=kind)
1632            assert_equal(c, a, msg)
1633
1634    @parametrize(
1635        "dtype",
1636        [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64],
1637    )
1638    def test_sort_signed(self, dtype):
1639        a = np.arange(-50, 51, dtype=dtype)
1640        b = np.flip(a)
1641        for kind in self.sort_kinds:
1642            msg = f"scalar sort, kind={kind}"
1643            c = a.copy()
1644            c.sort(kind=kind)
1645            assert_equal(c, a, msg)
1646            c = b.copy()
1647            c.sort(kind=kind)
1648            assert_equal(c, a, msg)
1649
1650    @xpassIfTorchDynamo  # (reason="sort complex")
1651    @parametrize("dtype", [np.float32, np.float64])
1652    @parametrize("part", ["real", "imag"])
1653    def test_sort_complex(self, part, dtype):
1654        # test complex sorts. These use the same code as the scalars
1655        # but the compare function differs.
1656        cdtype = {
1657            np.single: np.csingle,
1658            np.double: np.cdouble,
1659        }[dtype]
1660        a = np.arange(-50, 51, dtype=dtype)
1661        b = a[::-1].copy()
1662        ai = (a * (1 + 1j)).astype(cdtype)
1663        bi = (b * (1 + 1j)).astype(cdtype)
1664        setattr(ai, part, 1)
1665        setattr(bi, part, 1)
1666        for kind in self.sort_kinds:
1667            msg = f"complex sort, {part} part == 1, kind={kind}"
1668            c = ai.copy()
1669            c.sort(kind=kind)
1670            assert_equal(c, ai, msg)
1671            c = bi.copy()
1672            c.sort(kind=kind)
1673            assert_equal(c, ai, msg)
1674
1675    def test_sort_axis(self):
1676        # check axis handling. This should be the same for all type
1677        # specific sorts, so we only check it for one type and one kind
1678        a = np.array([[3, 2], [1, 0]])
1679        b = np.array([[1, 0], [3, 2]])
1680        c = np.array([[2, 3], [0, 1]])
1681        d = a.copy()
1682        d.sort(axis=0)
1683        assert_equal(d, b, "test sort with axis=0")
1684        d = a.copy()
1685        d.sort(axis=1)
1686        assert_equal(d, c, "test sort with axis=1")
1687        d = a.copy()
1688        d.sort()
1689        assert_equal(d, c, "test sort with default axis")
1690
1691    def test_sort_size_0(self):
1692        # check axis handling for multidimensional empty arrays
1693        a = np.array([])
1694        a = a.reshape(3, 2, 1, 0)
1695        for axis in range(-a.ndim, a.ndim):
1696            msg = f"test empty array sort with axis={axis}"
1697            assert_equal(np.sort(a, axis=axis), a, msg)
1698        msg = "test empty array sort with axis=None"
1699        assert_equal(np.sort(a, axis=None), a.ravel(), msg)
1700
    @skip(reason="waaay tooo sloooow")
    def test_sort_degraded(self):
        # test degraded dataset would take minutes to run with normal qsort
        d = np.arange(1000000)
        do = d.copy()  # sorted reference to compare against at the end
        x = d  # x aliases d, so the swaps below permute d itself
        # create a median of 3 killer where each median is the sorted second
        # last element of the quicksort partition
        while x.size > 3:
            mid = x.size // 2
            x[mid], x[-2] = x[-2], x[mid]
            # Shrink the working slice; a basic slice still shares d's memory.
            x = x[:-2]

        # Sorting the adversarial permutation must still recover arange.
        assert_equal(np.sort(d), do)
        assert_equal(d[np.argsort(d)], do)
1716
1717    @xfail  # (reason="order='F'")
1718    def test_copy(self):
1719        def assert_fortran(arr):
1720            assert_(arr.flags.fortran)
1721            assert_(arr.flags.f_contiguous)
1722            assert_(not arr.flags.c_contiguous)
1723
1724        def assert_c(arr):
1725            assert_(not arr.flags.fortran)
1726            assert_(not arr.flags.f_contiguous)
1727            assert_(arr.flags.c_contiguous)
1728
1729        a = np.empty((2, 2), order="F")
1730        # Test copying a Fortran array
1731        assert_c(a.copy())
1732        assert_c(a.copy("C"))
1733        assert_fortran(a.copy("F"))
1734        assert_fortran(a.copy("A"))
1735
1736        # Now test starting with a C array.
1737        a = np.empty((2, 2), order="C")
1738        assert_c(a.copy())
1739        assert_c(a.copy("C"))
1740        assert_fortran(a.copy("F"))
1741        assert_c(a.copy("A"))
1742
1743    @skip(reason="no .ctypes attribute")
1744    @parametrize("dtype", [np.int32])
1745    def test__deepcopy__(self, dtype):
1746        # Force the entry of NULLs into array
1747        a = np.empty(4, dtype=dtype)
1748        ctypes.memset(a.ctypes.data, 0, a.nbytes)
1749
1750        # Ensure no error is raised, see gh-21833
1751        b = a.__deepcopy__({})
1752
1753        a[0] = 42
1754        with pytest.raises(AssertionError):
1755            assert_array_equal(a, b)
1756
1757    def test_argsort(self):
1758        # all c scalar argsorts use the same code with different types
1759        # so it suffices to run a quick check with one type. The number
1760        # of sorted items must be greater than ~50 to check the actual
1761        # algorithm because quick and merge sort fall over to insertion
1762        # sort for small arrays.
1763
1764        for dtype in [np.int32, np.uint8, np.float32]:
1765            a = np.arange(101, dtype=dtype)
1766            b = np.flip(a)
1767            for kind in self.sort_kinds:
1768                msg = f"scalar argsort, kind={kind}, dtype={dtype}"
1769                assert_equal(a.copy().argsort(kind=kind), a, msg)
1770                assert_equal(b.copy().argsort(kind=kind), b, msg)
1771
1772    @skip(reason="argsort complex")
1773    def test_argsort_complex(self):
1774        a = np.arange(101, dtype=np.float32)
1775        b = np.flip(a)
1776
1777        # test complex argsorts. These use the same code as the scalars
1778        # but the compare function differs.
1779        ai = a * 1j + 1
1780        bi = b * 1j + 1
1781        for kind in self.sort_kinds:
1782            msg = f"complex argsort, kind={kind}"
1783            assert_equal(ai.copy().argsort(kind=kind), a, msg)
1784            assert_equal(bi.copy().argsort(kind=kind), b, msg)
1785        ai = a + 1j
1786        bi = b + 1j
1787        for kind in self.sort_kinds:
1788            msg = f"complex argsort, kind={kind}"
1789            assert_equal(ai.copy().argsort(kind=kind), a, msg)
1790            assert_equal(bi.copy().argsort(kind=kind), b, msg)
1791
1792        # test argsort of complex arrays requiring byte-swapping, gh-5441
1793        for endianness in "<>":
1794            for dt in np.typecodes["Complex"]:
1795                arr = np.array([1 + 3.0j, 2 + 2.0j, 3 + 1.0j], dtype=endianness + dt)
1796                msg = f"byte-swapped complex argsort, dtype={dt}"
1797                assert_equal(arr.argsort(), np.arange(len(arr), dtype=np.intp), msg)
1798
1799    @xpassIfTorchDynamo  # (reason="argsort axis TODO")
1800    def test_argsort_axis(self):
1801        # check axis handling. This should be the same for all type
1802        # specific argsorts, so we only check it for one type and one kind
1803        a = np.array([[3, 2], [1, 0]])
1804        b = np.array([[1, 1], [0, 0]])
1805        c = np.array([[1, 0], [1, 0]])
1806        assert_equal(a.copy().argsort(axis=0), b)
1807        assert_equal(a.copy().argsort(axis=1), c)
1808        assert_equal(a.copy().argsort(), c)
1809
1810        # check axis handling for multidimensional empty arrays
1811        a = np.array([])
1812        a = a.reshape(3, 2, 1, 0)
1813        for axis in range(-a.ndim, a.ndim):
1814            msg = f"test empty array argsort with axis={axis}"
1815            assert_equal(np.argsort(a, axis=axis), np.zeros_like(a, dtype=np.intp), msg)
1816        msg = "test empty array argsort with axis=None"
1817        assert_equal(
1818            np.argsort(a, axis=None), np.zeros_like(a.ravel(), dtype=np.intp), msg
1819        )
1820
1821        # check that stable argsorts are stable
1822        r = np.arange(100)
1823        # scalars
1824        a = np.zeros(100)
1825        assert_equal(a.argsort(kind="m"), r)
1826        # complex
1827        a = np.zeros(100, dtype=complex)
1828        assert_equal(a.argsort(kind="m"), r)
1829        # string
1830        a = np.array(["aaaaaaaaa" for i in range(100)])
1831        assert_equal(a.argsort(kind="m"), r)
1832        # unicode
1833        a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.str_)
1834        assert_equal(a.argsort(kind="m"), r)
1835
1836    @xpassIfTorchDynamo  # (reason="TODO: searchsorted with nans differs in pytorch")
1837    @parametrize(
1838        "a",
1839        [
1840            subtest(np.array([0, 1, np.nan], dtype=np.float16), name="f16"),
1841            subtest(np.array([0, 1, np.nan], dtype=np.float32), name="f32"),
1842            subtest(np.array([0, 1, np.nan]), name="default_dtype"),
1843        ],
1844    )
1845    def test_searchsorted_floats(self, a):
1846        # test for floats arrays containing nans. Explicitly test
1847        # half, single, and double precision floats to verify that
1848        # the NaN-handling is correct.
1849        msg = f"Test real ({a.dtype}) searchsorted with nans, side='l'"
1850        b = a.searchsorted(a, side="left")
1851        assert_equal(b, np.arange(3), msg)
1852        msg = f"Test real ({a.dtype}) searchsorted with nans, side='r'"
1853        b = a.searchsorted(a, side="right")
1854        assert_equal(b, np.arange(1, 4), msg)
1855        # check keyword arguments
1856        a.searchsorted(v=1)
1857        x = np.array([0, 1, np.nan], dtype="float32")
1858        y = np.searchsorted(x, x[-1])
1859        assert_equal(y, 2)
1860
1861    @xfail  # (
1862    #    reason="'searchsorted_out_cpu' not implemented for 'ComplexDouble'"
1863    # )
1864    def test_searchsorted_complex(self):
1865        # test for complex arrays containing nans.
1866        # The search sorted routines use the compare functions for the
1867        # array type, so this checks if that is consistent with the sort
1868        # order.
1869        # check double complex
1870        a = np.zeros(9, dtype=np.complex128)
1871        a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan]
1872        a.imag += [0, 1, 0, 1, np.nan, np.nan, 0, 1, np.nan]
1873        msg = "Test complex searchsorted with nans, side='l'"
1874        b = a.searchsorted(a, side="left")
1875        assert_equal(b, np.arange(9), msg)
1876        msg = "Test complex searchsorted with nans, side='r'"
1877        b = a.searchsorted(a, side="right")
1878        assert_equal(b, np.arange(1, 10), msg)
1879        msg = "Test searchsorted with little endian, side='l'"
1880        a = np.array([0, 128], dtype="<i4")
1881        b = a.searchsorted(np.array(128, dtype="<i4"))
1882        assert_equal(b, 1, msg)
1883        msg = "Test searchsorted with big endian, side='l'"
1884        a = np.array([0, 128], dtype=">i4")
1885        b = a.searchsorted(np.array(128, dtype=">i4"))
1886        assert_equal(b, 1, msg)
1887
1888    def test_searchsorted_n_elements(self):
1889        # Check 0 elements
1890        a = np.ones(0)
1891        b = a.searchsorted([0, 1, 2], "left")
1892        assert_equal(b, [0, 0, 0])
1893        b = a.searchsorted([0, 1, 2], "right")
1894        assert_equal(b, [0, 0, 0])
1895        a = np.ones(1)
1896        # Check 1 element
1897        b = a.searchsorted([0, 1, 2], "left")
1898        assert_equal(b, [0, 0, 1])
1899        b = a.searchsorted([0, 1, 2], "right")
1900        assert_equal(b, [0, 1, 1])
1901        # Check all elements equal
1902        a = np.ones(2)
1903        b = a.searchsorted([0, 1, 2], "left")
1904        assert_equal(b, [0, 0, 2])
1905        b = a.searchsorted([0, 1, 2], "right")
1906        assert_equal(b, [0, 2, 2])
1907
1908    @xpassIfTorchDynamo  # (
1909    #    reason="RuntimeError: self.storage_offset() must be divisible by 8"
1910    # )
1911    def test_searchsorted_unaligned_array(self):
1912        # Test searching unaligned array
1913        a = np.arange(10)
1914        aligned = np.empty(a.itemsize * a.size + 1, dtype="uint8")
1915        unaligned = aligned[1:].view(a.dtype)
1916        unaligned[:] = a
1917        # Test searching unaligned array
1918        b = unaligned.searchsorted(a, "left")
1919        assert_equal(b, a)
1920        b = unaligned.searchsorted(a, "right")
1921        assert_equal(b, a + 1)
1922        # Test searching for unaligned keys
1923        b = a.searchsorted(unaligned, "left")
1924        assert_equal(b, a)
1925        b = a.searchsorted(unaligned, "right")
1926        assert_equal(b, a + 1)
1927
1928    def test_searchsorted_resetting(self):
1929        # Test smart resetting of binsearch indices
1930        a = np.arange(5)
1931        b = a.searchsorted([6, 5, 4], "left")
1932        assert_equal(b, [5, 5, 4])
1933        b = a.searchsorted([6, 5, 4], "right")
1934        assert_equal(b, [5, 5, 5])
1935
1936    def test_searchsorted_type_specific(self):
1937        # Test all type specific binary search functions
1938        types = "".join((np.typecodes["AllInteger"], np.typecodes["Float"]))
1939        for dt in types:
1940            if dt == "?":
1941                a = np.arange(2, dtype=dt)
1942                out = np.arange(2)
1943            else:
1944                a = np.arange(0, 5, dtype=dt)
1945                out = np.arange(5)
1946            b = a.searchsorted(a, "left")
1947            assert_equal(b, out)
1948            b = a.searchsorted(a, "right")
1949            assert_equal(b, out + 1)
1950
1951    @xpassIfTorchDynamo  # (reason="ndarray ctor")
1952    def test_searchsorted_type_specific_2(self):
1953        # Test all type specific binary search functions
1954        types = "".join((np.typecodes["AllInteger"], np.typecodes["AllFloat"], "?"))
1955        for dt in types:
1956            if dt == "?":
1957                a = np.arange(2, dtype=dt)
1958                out = np.arange(2)
1959            else:
1960                a = np.arange(0, 5, dtype=dt)
1961                out = np.arange(5)
1962
1963            # Test empty array, use a fresh array to get warnings in
1964            # valgrind if access happens.
1965            e = np.ndarray(shape=0, buffer=b"", dtype=dt)
1966            b = e.searchsorted(a, "left")
1967            assert_array_equal(b, np.zeros(len(a), dtype=np.intp))
1968            b = a.searchsorted(e, "left")
1969            assert_array_equal(b, np.zeros(0, dtype=np.intp))
1970
1971    def test_searchsorted_with_invalid_sorter(self):
1972        a = np.array([5, 2, 1, 3, 4])
1973        s = np.argsort(a)
1974        assert_raises((TypeError, RuntimeError), np.searchsorted, a, 0, sorter=[1.1])
1975        assert_raises(
1976            (ValueError, RuntimeError), np.searchsorted, a, 0, sorter=[1, 2, 3, 4]
1977        )
1978        assert_raises(
1979            (ValueError, RuntimeError), np.searchsorted, a, 0, sorter=[1, 2, 3, 4, 5, 6]
1980        )
1981
1982        # bounds check : XXX torch does not raise
1983        # assert_raises(ValueError, np.searchsorted, a, 4, sorter=[0, 1, 2, 3, 5])
1984        # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3])
1985        # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3])
1986
1987    @xpassIfTorchDynamo  # (reason="self.storage_offset() must be divisible by 8")
1988    def test_searchsorted_with_sorter(self):
1989        a = np.random.rand(300)
1990        s = a.argsort()
1991        b = np.sort(a)
1992        k = np.linspace(0, 1, 20)
1993        assert_equal(b.searchsorted(k), a.searchsorted(k, sorter=s))
1994
1995        a = np.array([0, 1, 2, 3, 5] * 20)
1996        s = a.argsort()
1997        k = [0, 1, 2, 3, 5]
1998        expected = [0, 20, 40, 60, 80]
1999        assert_equal(a.searchsorted(k, side="left", sorter=s), expected)
2000        expected = [20, 40, 60, 80, 100]
2001        assert_equal(a.searchsorted(k, side="right", sorter=s), expected)
2002
2003        # Test searching unaligned array
2004        keys = np.arange(10)
2005        a = keys.copy()
2006        np.random.shuffle(s)
2007        s = a.argsort()
2008        aligned = np.empty(a.itemsize * a.size + 1, dtype="uint8")
2009        unaligned = aligned[1:].view(a.dtype)
2010        # Test searching unaligned array
2011        unaligned[:] = a
2012        b = unaligned.searchsorted(keys, "left", s)
2013        assert_equal(b, keys)
2014        b = unaligned.searchsorted(keys, "right", s)
2015        assert_equal(b, keys + 1)
2016        # Test searching for unaligned keys
2017        unaligned[:] = keys
2018        b = a.searchsorted(unaligned, "left", s)
2019        assert_equal(b, keys)
2020        b = a.searchsorted(unaligned, "right", s)
2021        assert_equal(b, keys + 1)
2022
2023        # Test all type specific indirect binary search functions
2024        types = "".join((np.typecodes["AllInteger"], np.typecodes["AllFloat"], "?"))
2025        for dt in types:
2026            if dt == "?":
2027                a = np.array([1, 0], dtype=dt)
2028                # We want the sorter array to be of a type that is different
2029                # from np.intp in all platforms, to check for #4698
2030                s = np.array([1, 0], dtype=np.int16)
2031                out = np.array([1, 0])
2032            else:
2033                a = np.array([3, 4, 1, 2, 0], dtype=dt)
2034                # We want the sorter array to be of a type that is different
2035                # from np.intp in all platforms, to check for #4698
2036                s = np.array([4, 2, 3, 0, 1], dtype=np.int16)
2037                out = np.array([3, 4, 1, 2, 0], dtype=np.intp)
2038            b = a.searchsorted(a, "left", s)
2039            assert_equal(b, out)
2040            b = a.searchsorted(a, "right", s)
2041            assert_equal(b, out + 1)
2042            # Test empty array, use a fresh array to get warnings in
2043            # valgrind if access happens.
2044            e = np.ndarray(shape=0, buffer=b"", dtype=dt)
2045            b = e.searchsorted(a, "left", s[:0])
2046            assert_array_equal(b, np.zeros(len(a), dtype=np.intp))
2047            b = a.searchsorted(e, "left", s)
2048            assert_array_equal(b, np.zeros(0, dtype=np.intp))
2049
2050        # Test non-contiguous sorter array
2051        a = np.array([3, 4, 1, 2, 0])
2052        srt = np.empty((10,), dtype=np.intp)
2053        srt[1::2] = -1
2054        srt[::2] = [4, 2, 3, 0, 1]
2055        s = srt[::2]
2056        out = np.array([3, 4, 1, 2, 0], dtype=np.intp)
2057        b = a.searchsorted(a, "left", s)
2058        assert_equal(b, out)
2059        b = a.searchsorted(a, "right", s)
2060        assert_equal(b, out + 1)
2061
2062    @xpassIfTorchDynamo  # (reason="TODO argpartition")
2063    @parametrize("dtype", "efdFDBbhil?")
2064    def test_argpartition_out_of_range(self, dtype):
2065        # Test out of range values in kth raise an error, gh-5469
2066        d = np.arange(10).astype(dtype=dtype)
2067        assert_raises(ValueError, d.argpartition, 10)
2068        assert_raises(ValueError, d.argpartition, -11)
2069
2070    @xpassIfTorchDynamo  # (reason="TODO partition")
2071    @parametrize("dtype", "efdFDBbhil?")
2072    def test_partition_out_of_range(self, dtype):
2073        # Test out of range values in kth raise an error, gh-5469
2074        d = np.arange(10).astype(dtype=dtype)
2075        assert_raises(ValueError, d.partition, 10)
2076        assert_raises(ValueError, d.partition, -11)
2077
2078    @xpassIfTorchDynamo  # (reason="TODO argpartition")
2079    def test_argpartition_integer(self):
2080        # Test non-integer values in kth raise an error/
2081        d = np.arange(10)
2082        assert_raises(TypeError, d.argpartition, 9.0)
2083        # Test also for generic type argpartition, which uses sorting
2084        # and used to not bound check kth
2085        d_obj = np.arange(10, dtype=object)
2086        assert_raises(TypeError, d_obj.argpartition, 9.0)
2087
2088    @xpassIfTorchDynamo  # (reason="TODO partition")
2089    def test_partition_integer(self):
2090        # Test out of range values in kth raise an error, gh-5469
2091        d = np.arange(10)
2092        assert_raises(TypeError, d.partition, 9.0)
2093        # Test also for generic type partition, which uses sorting
2094        # and used to not bound check kth
2095        d_obj = np.arange(10, dtype=object)
2096        assert_raises(TypeError, d_obj.partition, 9.0)
2097
2098    @xpassIfTorchDynamo  # (reason="TODO partition")
2099    @parametrize("kth_dtype", "Bbhil")
2100    def test_partition_empty_array(self, kth_dtype):
2101        # check axis handling for multidimensional empty arrays
2102        kth = np.array(0, dtype=kth_dtype)[()]
2103        a = np.array([])
2104        a.shape = (3, 2, 1, 0)
2105        for axis in range(-a.ndim, a.ndim):
2106            msg = f"test empty array partition with axis={axis}"
2107            assert_equal(np.partition(a, kth, axis=axis), a, msg)
2108        msg = "test empty array partition with axis=None"
2109        assert_equal(np.partition(a, kth, axis=None), a.ravel(), msg)
2110
2111    @xpassIfTorchDynamo  # (reason="TODO argpartition")
2112    @parametrize("kth_dtype", "Bbhil")
2113    def test_argpartition_empty_array(self, kth_dtype):
2114        # check axis handling for multidimensional empty arrays
2115        kth = np.array(0, dtype=kth_dtype)[()]
2116        a = np.array([])
2117        a.shape = (3, 2, 1, 0)
2118        for axis in range(-a.ndim, a.ndim):
2119            msg = f"test empty array argpartition with axis={axis}"
2120            assert_equal(
2121                np.partition(a, kth, axis=axis), np.zeros_like(a, dtype=np.intp), msg
2122            )
2123        msg = "test empty array argpartition with axis=None"
2124        assert_equal(
2125            np.partition(a, kth, axis=None),
2126            np.zeros_like(a.ravel(), dtype=np.intp),
2127            msg,
2128        )
2129
2130    @xpassIfTorchDynamo  # (reason="TODO partition")
2131    def test_partition(self):
2132        d = np.arange(10)
2133        assert_raises(TypeError, np.partition, d, 2, kind=1)
2134        assert_raises(ValueError, np.partition, d, 2, kind="nonsense")
2135        assert_raises(ValueError, np.argpartition, d, 2, kind="nonsense")
2136        assert_raises(ValueError, d.partition, 2, axis=0, kind="nonsense")
2137        assert_raises(ValueError, d.argpartition, 2, axis=0, kind="nonsense")
2138        for k in ("introselect",):
2139            d = np.array([])
2140            assert_array_equal(np.partition(d, 0, kind=k), d)
2141            assert_array_equal(np.argpartition(d, 0, kind=k), d)
2142            d = np.ones(1)
2143            assert_array_equal(np.partition(d, 0, kind=k)[0], d)
2144            assert_array_equal(
2145                d[np.argpartition(d, 0, kind=k)], np.partition(d, 0, kind=k)
2146            )
2147
2148            # kth not modified
2149            kth = np.array([30, 15, 5])
2150            okth = kth.copy()
2151            np.partition(np.arange(40), kth)
2152            assert_array_equal(kth, okth)
2153
2154            for r in ([2, 1], [1, 2], [1, 1]):
2155                d = np.array(r)
2156                tgt = np.sort(d)
2157                assert_array_equal(np.partition(d, 0, kind=k)[0], tgt[0])
2158                assert_array_equal(np.partition(d, 1, kind=k)[1], tgt[1])
2159                assert_array_equal(
2160                    d[np.argpartition(d, 0, kind=k)], np.partition(d, 0, kind=k)
2161                )
2162                assert_array_equal(
2163                    d[np.argpartition(d, 1, kind=k)], np.partition(d, 1, kind=k)
2164                )
2165                for i in range(d.size):
2166                    d[i:].partition(0, kind=k)
2167                assert_array_equal(d, tgt)
2168
2169            for r in (
2170                [3, 2, 1],
2171                [1, 2, 3],
2172                [2, 1, 3],
2173                [2, 3, 1],
2174                [1, 1, 1],
2175                [1, 2, 2],
2176                [2, 2, 1],
2177                [1, 2, 1],
2178            ):
2179                d = np.array(r)
2180                tgt = np.sort(d)
2181                assert_array_equal(np.partition(d, 0, kind=k)[0], tgt[0])
2182                assert_array_equal(np.partition(d, 1, kind=k)[1], tgt[1])
2183                assert_array_equal(np.partition(d, 2, kind=k)[2], tgt[2])
2184                assert_array_equal(
2185                    d[np.argpartition(d, 0, kind=k)], np.partition(d, 0, kind=k)
2186                )
2187                assert_array_equal(
2188                    d[np.argpartition(d, 1, kind=k)], np.partition(d, 1, kind=k)
2189                )
2190                assert_array_equal(
2191                    d[np.argpartition(d, 2, kind=k)], np.partition(d, 2, kind=k)
2192                )
2193                for i in range(d.size):
2194                    d[i:].partition(0, kind=k)
2195                assert_array_equal(d, tgt)
2196
2197            d = np.ones(50)
2198            assert_array_equal(np.partition(d, 0, kind=k), d)
2199            assert_array_equal(
2200                d[np.argpartition(d, 0, kind=k)], np.partition(d, 0, kind=k)
2201            )
2202
2203            # sorted
2204            d = np.arange(49)
2205            assert_equal(np.partition(d, 5, kind=k)[5], 5)
2206            assert_equal(np.partition(d, 15, kind=k)[15], 15)
2207            assert_array_equal(
2208                d[np.argpartition(d, 5, kind=k)], np.partition(d, 5, kind=k)
2209            )
2210            assert_array_equal(
2211                d[np.argpartition(d, 15, kind=k)], np.partition(d, 15, kind=k)
2212            )
2213
2214            # rsorted
2215            d = np.arange(47)[::-1]
2216            assert_equal(np.partition(d, 6, kind=k)[6], 6)
2217            assert_equal(np.partition(d, 16, kind=k)[16], 16)
2218            assert_array_equal(
2219                d[np.argpartition(d, 6, kind=k)], np.partition(d, 6, kind=k)
2220            )
2221            assert_array_equal(
2222                d[np.argpartition(d, 16, kind=k)], np.partition(d, 16, kind=k)
2223            )
2224
2225            assert_array_equal(np.partition(d, -6, kind=k), np.partition(d, 41, kind=k))
2226            assert_array_equal(
2227                np.partition(d, -16, kind=k), np.partition(d, 31, kind=k)
2228            )
2229            assert_array_equal(
2230                d[np.argpartition(d, -6, kind=k)], np.partition(d, 41, kind=k)
2231            )
2232
2233            # median of 3 killer, O(n^2) on pure median 3 pivot quickselect
2234            # exercises the median of median of 5 code used to keep O(n)
2235            d = np.arange(1000000)
2236            x = np.roll(d, d.size // 2)
2237            mid = x.size // 2 + 1
2238            assert_equal(np.partition(x, mid)[mid], mid)
2239            d = np.arange(1000001)
2240            x = np.roll(d, d.size // 2 + 1)
2241            mid = x.size // 2 + 1
2242            assert_equal(np.partition(x, mid)[mid], mid)
2243
2244            # max
2245            d = np.ones(10)
2246            d[1] = 4
2247            assert_equal(np.partition(d, (2, -1))[-1], 4)
2248            assert_equal(np.partition(d, (2, -1))[2], 1)
2249            assert_equal(d[np.argpartition(d, (2, -1))][-1], 4)
2250            assert_equal(d[np.argpartition(d, (2, -1))][2], 1)
2251            d[1] = np.nan
2252            assert_(np.isnan(d[np.argpartition(d, (2, -1))][-1]))
2253            assert_(np.isnan(np.partition(d, (2, -1))[-1]))
2254
2255            # equal elements
2256            d = np.arange(47) % 7
2257            tgt = np.sort(np.arange(47) % 7)
2258            np.random.shuffle(d)
2259            for i in range(d.size):
2260                assert_equal(np.partition(d, i, kind=k)[i], tgt[i])
2261            assert_array_equal(
2262                d[np.argpartition(d, 6, kind=k)], np.partition(d, 6, kind=k)
2263            )
2264            assert_array_equal(
2265                d[np.argpartition(d, 16, kind=k)], np.partition(d, 16, kind=k)
2266            )
2267            for i in range(d.size):
2268                d[i:].partition(0, kind=k)
2269            assert_array_equal(d, tgt)
2270
2271            d = np.array(
2272                [0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 9]
2273            )
2274            kth = [0, 3, 19, 20]
2275            assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
2276            assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
2277
2278            d = np.array([2, 1])
2279            d.partition(0, kind=k)
2280            assert_raises(ValueError, d.partition, 2)
2281            assert_raises(np.AxisError, d.partition, 3, axis=1)
2282            assert_raises(ValueError, np.partition, d, 2)
2283            assert_raises(np.AxisError, np.partition, d, 2, axis=1)
2284            assert_raises(ValueError, d.argpartition, 2)
2285            assert_raises(np.AxisError, d.argpartition, 3, axis=1)
2286            assert_raises(ValueError, np.argpartition, d, 2)
2287            assert_raises(np.AxisError, np.argpartition, d, 2, axis=1)
2288            d = np.arange(10).reshape((2, 5))
2289            d.partition(1, axis=0, kind=k)
2290            d.partition(4, axis=1, kind=k)
2291            np.partition(d, 1, axis=0, kind=k)
2292            np.partition(d, 4, axis=1, kind=k)
2293            np.partition(d, 1, axis=None, kind=k)
2294            np.partition(d, 9, axis=None, kind=k)
2295            d.argpartition(1, axis=0, kind=k)
2296            d.argpartition(4, axis=1, kind=k)
2297            np.argpartition(d, 1, axis=0, kind=k)
2298            np.argpartition(d, 4, axis=1, kind=k)
2299            np.argpartition(d, 1, axis=None, kind=k)
2300            np.argpartition(d, 9, axis=None, kind=k)
2301            assert_raises(ValueError, d.partition, 2, axis=0)
2302            assert_raises(ValueError, d.partition, 11, axis=1)
2303            assert_raises(TypeError, d.partition, 2, axis=None)
2304            assert_raises(ValueError, np.partition, d, 9, axis=1)
2305            assert_raises(ValueError, np.partition, d, 11, axis=None)
2306            assert_raises(ValueError, d.argpartition, 2, axis=0)
2307            assert_raises(ValueError, d.argpartition, 11, axis=1)
2308            assert_raises(ValueError, np.argpartition, d, 9, axis=1)
2309            assert_raises(ValueError, np.argpartition, d, 11, axis=None)
2310
2311            td = [
2312                (dt, s) for dt in [np.int32, np.float32, np.complex64] for s in (9, 16)
2313            ]
2314            for dt, s in td:
2315                aae = assert_array_equal
2316                at = assert_
2317
2318                d = np.arange(s, dtype=dt)
2319                np.random.shuffle(d)
2320                d1 = np.tile(np.arange(s, dtype=dt), (4, 1))
2321                map(np.random.shuffle, d1)
2322                d0 = np.transpose(d1)
2323                for i in range(d.size):
2324                    p = np.partition(d, i, kind=k)
2325                    assert_equal(p[i], i)
2326                    # all before are smaller
2327                    assert_array_less(p[:i], p[i])
2328                    # all after are larger
2329                    assert_array_less(p[i], p[i + 1 :])
2330                    aae(p, d[np.argpartition(d, i, kind=k)])
2331
2332                    p = np.partition(d1, i, axis=1, kind=k)
2333                    aae(p[:, i], np.array([i] * d1.shape[0], dtype=dt))
2334                    # array_less does not seem to work right
2335                    at(
2336                        (p[:, :i].T <= p[:, i]).all(),
2337                        msg="%d: %r <= %r" % (i, p[:, i], p[:, :i].T),
2338                    )
2339                    at(
2340                        (p[:, i + 1 :].T > p[:, i]).all(),
2341                        msg="%d: %r < %r" % (i, p[:, i], p[:, i + 1 :].T),
2342                    )
2343                    aae(
2344                        p,
2345                        d1[
2346                            np.arange(d1.shape[0])[:, None],
2347                            np.argpartition(d1, i, axis=1, kind=k),
2348                        ],
2349                    )
2350
2351                    p = np.partition(d0, i, axis=0, kind=k)
2352                    aae(p[i, :], np.array([i] * d1.shape[0], dtype=dt))
2353                    # array_less does not seem to work right
2354                    at(
2355                        (p[:i, :] <= p[i, :]).all(),
2356                        msg="%d: %r <= %r" % (i, p[i, :], p[:i, :]),
2357                    )
2358                    at(
2359                        (p[i + 1 :, :] > p[i, :]).all(),
2360                        msg="%d: %r < %r" % (i, p[i, :], p[:, i + 1 :]),
2361                    )
2362                    aae(
2363                        p,
2364                        d0[
2365                            np.argpartition(d0, i, axis=0, kind=k),
2366                            np.arange(d0.shape[1])[None, :],
2367                        ],
2368                    )
2369
2370                    # check inplace
2371                    dc = d.copy()
2372                    dc.partition(i, kind=k)
2373                    assert_equal(dc, np.partition(d, i, kind=k))
2374                    dc = d0.copy()
2375                    dc.partition(i, axis=0, kind=k)
2376                    assert_equal(dc, np.partition(d0, i, axis=0, kind=k))
2377                    dc = d1.copy()
2378                    dc.partition(i, axis=1, kind=k)
2379                    assert_equal(dc, np.partition(d1, i, axis=1, kind=k))
2380
2381    def assert_partitioned(self, d, kth):
2382        prev = 0
2383        for k in np.sort(kth):
2384            assert_array_less(d[prev:k], d[k], err_msg="kth %d" % k)
2385            assert_(
2386                (d[k:] >= d[k]).all(),
2387                msg="kth %d, %r not greater equal %d" % (k, d[k:], d[k]),
2388            )
2389            prev = k + 1
2390
2391    @xpassIfTorchDynamo  # (reason="TODO partition")
2392    def test_partition_iterative(self):
2393        d = np.arange(17)
2394        kth = (0, 1, 2, 429, 231)
2395        assert_raises(ValueError, d.partition, kth)
2396        assert_raises(ValueError, d.argpartition, kth)
2397        d = np.arange(10).reshape((2, 5))
2398        assert_raises(ValueError, d.partition, kth, axis=0)
2399        assert_raises(ValueError, d.partition, kth, axis=1)
2400        assert_raises(ValueError, np.partition, d, kth, axis=1)
2401        assert_raises(ValueError, np.partition, d, kth, axis=None)
2402
2403        d = np.array([3, 4, 2, 1])
2404        p = np.partition(d, (0, 3))
2405        self.assert_partitioned(p, (0, 3))
2406        self.assert_partitioned(d[np.argpartition(d, (0, 3))], (0, 3))
2407
2408        assert_array_equal(p, np.partition(d, (-3, -1)))
2409        assert_array_equal(p, d[np.argpartition(d, (-3, -1))])
2410
2411        d = np.arange(17)
2412        np.random.shuffle(d)
2413        d.partition(range(d.size))
2414        assert_array_equal(np.arange(17), d)
2415        np.random.shuffle(d)
2416        assert_array_equal(np.arange(17), d[d.argpartition(range(d.size))])
2417
2418        # test unsorted kth
2419        d = np.arange(17)
2420        np.random.shuffle(d)
2421        keys = np.array([1, 3, 8, -2])
2422        np.random.shuffle(d)
2423        p = np.partition(d, keys)
2424        self.assert_partitioned(p, keys)
2425        p = d[np.argpartition(d, keys)]
2426        self.assert_partitioned(p, keys)
2427        np.random.shuffle(keys)
2428        assert_array_equal(np.partition(d, keys), p)
2429        assert_array_equal(d[np.argpartition(d, keys)], p)
2430
2431        # equal kth
2432        d = np.arange(20)[::-1]
2433        self.assert_partitioned(np.partition(d, [5] * 4), [5])
2434        self.assert_partitioned(np.partition(d, [5] * 4 + [6, 13]), [5] * 4 + [6, 13])
2435        self.assert_partitioned(d[np.argpartition(d, [5] * 4)], [5])
2436        self.assert_partitioned(
2437            d[np.argpartition(d, [5] * 4 + [6, 13])], [5] * 4 + [6, 13]
2438        )
2439
2440        d = np.arange(12)
2441        np.random.shuffle(d)
2442        d1 = np.tile(np.arange(12), (4, 1))
2443        map(np.random.shuffle, d1)
2444        d0 = np.transpose(d1)
2445
2446        kth = (1, 6, 7, -1)
2447        p = np.partition(d1, kth, axis=1)
2448        pa = d1[np.arange(d1.shape[0])[:, None], d1.argpartition(kth, axis=1)]
2449        assert_array_equal(p, pa)
2450        for i in range(d1.shape[0]):
2451            self.assert_partitioned(p[i, :], kth)
2452        p = np.partition(d0, kth, axis=0)
2453        pa = d0[np.argpartition(d0, kth, axis=0), np.arange(d0.shape[1])[None, :]]
2454        assert_array_equal(p, pa)
2455        for i in range(d0.shape[1]):
2456            self.assert_partitioned(p[:, i], kth)
2457
2458    @xpassIfTorchDynamo  # (reason="TODO partition")
2459    def test_partition_fuzz(self):
2460        # a few rounds of random data testing
2461        for j in range(10, 30):
2462            for i in range(1, j - 2):
2463                d = np.arange(j)
2464                np.random.shuffle(d)
2465                d = d % np.random.randint(2, 30)
2466                idx = np.random.randint(d.size)
2467                kth = [0, idx, i, i + 1]
2468                tgt = np.sort(d)[kth]
2469                assert_array_equal(
2470                    np.partition(d, kth)[kth],
2471                    tgt,
2472                    err_msg=f"data: {d!r}\n kth: {kth!r}",
2473                )
2474
2475    @xpassIfTorchDynamo  # (reason="TODO partition")
2476    @parametrize("kth_dtype", "Bbhil")
2477    def test_argpartition_gh5524(self, kth_dtype):
2478        #  A test for functionality of argpartition on lists.
2479        kth = np.array(1, dtype=kth_dtype)[()]
2480        d = [6, 7, 3, 2, 9, 0]
2481        p = np.argpartition(d, kth)
2482        self.assert_partitioned(np.array(d)[p], [1])
2483
2484    @xpassIfTorchDynamo  # (reason="TODO order='F'")
2485    def test_flatten(self):
2486        x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
2487        x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32)
2488        y0 = np.array([1, 2, 3, 4, 5, 6], np.int32)
2489        y0f = np.array([1, 4, 2, 5, 3, 6], np.int32)
2490        y1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], np.int32)
2491        y1f = np.array([1, 5, 3, 7, 2, 6, 4, 8], np.int32)
2492        assert_equal(x0.flatten(), y0)
2493        assert_equal(x0.flatten("F"), y0f)
2494        assert_equal(x0.flatten("F"), x0.T.flatten())
2495        assert_equal(x1.flatten(), y1)
2496        assert_equal(x1.flatten("F"), y1f)
2497        assert_equal(x1.flatten("F"), x1.T.flatten())
2498
2499    @parametrize("func", (np.dot, np.matmul))
2500    def test_arr_mult(self, func):
2501        a = np.array([[1, 0], [0, 1]])
2502        b = np.array([[0, 1], [1, 0]])
2503        c = np.array([[9, 1], [1, -9]])
2504        d = np.arange(24).reshape(4, 6)
2505        ddt = np.array(
2506            [
2507                [55, 145, 235, 325],
2508                [145, 451, 757, 1063],
2509                [235, 757, 1279, 1801],
2510                [325, 1063, 1801, 2539],
2511            ]
2512        )
2513        dtd = np.array(
2514            [
2515                [504, 540, 576, 612, 648, 684],
2516                [540, 580, 620, 660, 700, 740],
2517                [576, 620, 664, 708, 752, 796],
2518                [612, 660, 708, 756, 804, 852],
2519                [648, 700, 752, 804, 856, 908],
2520                [684, 740, 796, 852, 908, 964],
2521            ]
2522        )
2523
2524        # gemm vs syrk optimizations
2525        for et in [np.float32, np.float64, np.complex64, np.complex128]:
2526            eaf = a.astype(et)
2527            assert_equal(func(eaf, eaf), eaf)
2528            assert_equal(func(eaf.T, eaf), eaf)
2529            assert_equal(func(eaf, eaf.T), eaf)
2530            assert_equal(func(eaf.T, eaf.T), eaf)
2531            assert_equal(func(eaf.T.copy(), eaf), eaf)
2532            assert_equal(func(eaf, eaf.T.copy()), eaf)
2533            assert_equal(func(eaf.T.copy(), eaf.T.copy()), eaf)
2534
2535        # syrk validations
2536        for et in [np.float32, np.float64, np.complex64, np.complex128]:
2537            eaf = a.astype(et)
2538            ebf = b.astype(et)
2539            assert_equal(func(ebf, ebf), eaf)
2540            assert_equal(func(ebf.T, ebf), eaf)
2541            assert_equal(func(ebf, ebf.T), eaf)
2542            assert_equal(func(ebf.T, ebf.T), eaf)
2543        # syrk - different shape
2544        for et in [np.float32, np.float64, np.complex64, np.complex128]:
2545            edf = d.astype(et)
2546            eddtf = ddt.astype(et)
2547            edtdf = dtd.astype(et)
2548            assert_equal(func(edf, edf.T), eddtf)
2549            assert_equal(func(edf.T, edf), edtdf)
2550
2551            assert_equal(
2552                func(edf[: edf.shape[0] // 2, :], edf[::2, :].T),
2553                func(edf[: edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy()),
2554            )
2555            assert_equal(
2556                func(edf[::2, :], edf[: edf.shape[0] // 2, :].T),
2557                func(edf[::2, :].copy(), edf[: edf.shape[0] // 2, :].T.copy()),
2558            )
2559
2560    @skip(reason="dot/matmul with negative strides")
2561    @parametrize("func", (np.dot, np.matmul))
2562    def test_arr_mult_2(self, func):
2563        # syrk - different shape, stride, and view validations
2564        for et in [np.float32, np.float64, np.complex64, np.complex128]:
2565            edf = d.astype(et)
2566            assert_equal(
2567                func(edf[::-1, :], edf.T), func(edf[::-1, :].copy(), edf.T.copy())
2568            )
2569            assert_equal(
2570                func(edf[:, ::-1], edf.T), func(edf[:, ::-1].copy(), edf.T.copy())
2571            )
2572            assert_equal(func(edf, edf[::-1, :].T), func(edf, edf[::-1, :].T.copy()))
2573            assert_equal(func(edf, edf[:, ::-1].T), func(edf, edf[:, ::-1].T.copy()))
2574
2575    @parametrize("func", (np.dot, np.matmul))
2576    @parametrize("dtype", "ifdFD")
2577    def test_no_dgemv(self, func, dtype):
2578        # check vector arg for contiguous before gemv
2579        # gh-12156
2580        a = np.arange(8.0, dtype=dtype).reshape(2, 4)
2581        b = np.broadcast_to(1.0, (4, 1))
2582        ret1 = func(a, b)
2583        ret2 = func(a, b.copy())
2584        assert_equal(ret1, ret2)
2585
2586        ret1 = func(b.T, a.T)
2587        ret2 = func(b.T.copy(), a.T)
2588        assert_equal(ret1, ret2)
2589
2590    @skip(reason="__array_interface__")
2591    @parametrize("func", (np.dot, np.matmul))
2592    @parametrize("dtype", "ifdFD")
2593    def test_no_dgemv_2(self, func, dtype):
2594        # check for unaligned data
2595        dt = np.dtype(dtype)
2596        a = np.zeros(8 * dt.itemsize // 2 + 1, dtype="int16")[1:].view(dtype)
2597        a = a.reshape(2, 4)
2598        b = a[0]
2599        # make sure it is not aligned
2600        assert_(a.__array_interface__["data"][0] % dt.itemsize != 0)
2601        ret1 = func(a, b)
2602        ret2 = func(a.copy(), b.copy())
2603        assert_equal(ret1, ret2)
2604
2605        ret1 = func(b.T, a.T)
2606        ret2 = func(b.T.copy(), a.T.copy())
2607        assert_equal(ret1, ret2)
2608
2609    def test_dot(self):
2610        a = np.array([[1, 0], [0, 1]])
2611        b = np.array([[0, 1], [1, 0]])
2612        c = np.array([[9, 1], [1, -9]])
2613        # function versus methods
2614        assert_equal(np.dot(a, b), a.dot(b))
2615        assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c))
2616
2617        # test passing in an output array
2618        c = np.zeros_like(a)
2619        a.dot(b, c)
2620        assert_equal(c, np.dot(a, b))
2621
2622        # test keyword args
2623        c = np.zeros_like(a)
2624        a.dot(b=b, out=c)
2625        assert_equal(c, np.dot(a, b))
2626
2627    @xpassIfTorchDynamo  # (reason="_aligned_zeros")
2628    def test_dot_out_mem_overlap(self):
2629        np.random.seed(1)
2630
2631        # Test BLAS and non-BLAS code paths, including all dtypes
2632        # that dot() supports
2633        dtypes = [np.dtype(code) for code in np.typecodes["All"] if code not in "USVM"]
2634        for dtype in dtypes:
2635            a = np.random.rand(3, 3).astype(dtype)
2636
2637            # Valid dot() output arrays must be aligned
2638            b = _aligned_zeros((3, 3), dtype=dtype)
2639            b[...] = np.random.rand(3, 3)
2640
2641            y = np.dot(a, b)
2642            x = np.dot(a, b, out=b)
2643            assert_equal(x, y, err_msg=repr(dtype))
2644
2645            # Check invalid output array
2646            assert_raises(ValueError, np.dot, a, b, out=b[::2])
2647            assert_raises(ValueError, np.dot, a, b, out=b.T)
2648
2649    @xpassIfTorchDynamo  # (reason="TODO: overlapping memor in matmul")
2650    def test_matmul_out(self):
2651        # overlapping memory
2652        a = np.arange(18).reshape(2, 3, 3)
2653        b = np.matmul(a, a)
2654        c = np.matmul(a, a, out=a)
2655        assert_(c is a)
2656        assert_equal(c, b)
2657        a = np.arange(18).reshape(2, 3, 3)
2658        c = np.matmul(a, a, out=a[::-1, ...])
2659        assert_(c.base is a.base)
2660        assert_equal(c, b)
2661
2662    def test_diagonal(self):
2663        a = np.arange(12).reshape((3, 4))
2664        assert_equal(a.diagonal(), [0, 5, 10])
2665        assert_equal(a.diagonal(0), [0, 5, 10])
2666        assert_equal(a.diagonal(1), [1, 6, 11])
2667        assert_equal(a.diagonal(-1), [4, 9])
2668        assert_raises(np.AxisError, a.diagonal, axis1=0, axis2=5)
2669        assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=0)
2670        assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=5)
2671        assert_raises((ValueError, RuntimeError), a.diagonal, axis1=1, axis2=1)
2672
2673        b = np.arange(8).reshape((2, 2, 2))
2674        assert_equal(b.diagonal(), [[0, 6], [1, 7]])
2675        assert_equal(b.diagonal(0), [[0, 6], [1, 7]])
2676        assert_equal(b.diagonal(1), [[2], [3]])
2677        assert_equal(b.diagonal(-1), [[4], [5]])
2678        assert_raises((ValueError, RuntimeError), b.diagonal, axis1=0, axis2=0)
2679        assert_equal(b.diagonal(0, 1, 2), [[0, 3], [4, 7]])
2680        assert_equal(b.diagonal(0, 0, 1), [[0, 6], [1, 7]])
2681        assert_equal(b.diagonal(offset=1, axis1=0, axis2=2), [[1], [3]])
2682        # Order of axis argument doesn't matter:
2683        assert_equal(b.diagonal(0, 2, 1), [[0, 3], [4, 7]])
2684
2685    @xfail  # (reason="no readonly views")
2686    def test_diagonal_view_notwriteable(self):
2687        a = np.eye(3).diagonal()
2688        assert_(not a.flags.writeable)
2689        assert_(not a.flags.owndata)
2690
2691        a = np.diagonal(np.eye(3))
2692        assert_(not a.flags.writeable)
2693        assert_(not a.flags.owndata)
2694
2695        a = np.diag(np.eye(3))
2696        assert_(not a.flags.writeable)
2697        assert_(not a.flags.owndata)
2698
2699    def test_diagonal_memleak(self):
2700        # Regression test for a bug that crept in at one point
2701        a = np.zeros((100, 100))
2702        if HAS_REFCOUNT:
2703            assert_(sys.getrefcount(a) < 50)
2704        for i in range(100):
2705            a.diagonal()
2706        if HAS_REFCOUNT:
2707            assert_(sys.getrefcount(a) < 50)
2708
2709    def test_size_zero_memleak(self):
2710        # Regression test for issue 9615
2711        # Exercises a special-case code path for dot products of length
2712        # zero in cblasfuncs (making it is specific to floating dtypes).
2713        a = np.array([], dtype=np.float64)
2714        x = np.array(2.0)
2715        for _ in range(100):
2716            np.dot(a, a, out=x)
2717        if HAS_REFCOUNT:
2718            assert_(sys.getrefcount(x) < 50)
2719
2720    def test_trace(self):
2721        a = np.arange(12).reshape((3, 4))
2722        assert_equal(a.trace(), 15)
2723        assert_equal(a.trace(0), 15)
2724        assert_equal(a.trace(1), 18)
2725        assert_equal(a.trace(-1), 13)
2726
2727        b = np.arange(8).reshape((2, 2, 2))
2728        assert_equal(b.trace(), [6, 8])
2729        assert_equal(b.trace(0), [6, 8])
2730        assert_equal(b.trace(1), [2, 3])
2731        assert_equal(b.trace(-1), [4, 5])
2732        assert_equal(b.trace(0, 0, 1), [6, 8])
2733        assert_equal(b.trace(0, 0, 2), [5, 9])
2734        assert_equal(b.trace(0, 1, 2), [3, 11])
2735        assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3])
2736
2737        out = np.array(1)
2738        ret = a.trace(out=out)
2739        assert ret is out
2740
2741    def test_put(self):
2742        icodes = np.typecodes["AllInteger"]
2743        fcodes = np.typecodes["AllFloat"]
2744        for dt in icodes + fcodes:
2745            tgt = np.array([0, 1, 0, 3, 0, 5], dtype=dt)
2746
2747            # test 1-d
2748            a = np.zeros(6, dtype=dt)
2749            a.put([1, 3, 5], [1, 3, 5])
2750            assert_equal(a, tgt)
2751
2752            # test 2-d
2753            a = np.zeros((2, 3), dtype=dt)
2754            a.put([1, 3, 5], [1, 3, 5])
2755            assert_equal(a, tgt.reshape(2, 3))
2756
2757        for dt in "?":
2758            tgt = np.array([False, True, False, True, False, True], dtype=dt)
2759
2760            # test 1-d
2761            a = np.zeros(6, dtype=dt)
2762            a.put([1, 3, 5], [True] * 3)
2763            assert_equal(a, tgt)
2764
2765            # test 2-d
2766            a = np.zeros((2, 3), dtype=dt)
2767            a.put([1, 3, 5], [True] * 3)
2768            assert_equal(a, tgt.reshape(2, 3))
2769
2770        # when calling np.put, make sure a
2771        # TypeError is raised if the object
2772        # isn't an ndarray
2773        bad_array = [1, 2, 3]
2774        assert_raises(TypeError, np.put, bad_array, [0, 2], 5)
2775
2776    @xpassIfTorchDynamo  # (reason="TODO: implement order='F'")
2777    def test_ravel(self):
2778        a = np.array([[0, 1], [2, 3]])
2779        assert_equal(a.ravel(), [0, 1, 2, 3])
2780        assert_(not a.ravel().flags.owndata)
2781        assert_equal(a.ravel("F"), [0, 2, 1, 3])
2782        assert_equal(a.ravel(order="C"), [0, 1, 2, 3])
2783        assert_equal(a.ravel(order="F"), [0, 2, 1, 3])
2784        assert_equal(a.ravel(order="A"), [0, 1, 2, 3])
2785        assert_(not a.ravel(order="A").flags.owndata)
2786        assert_equal(a.ravel(order="K"), [0, 1, 2, 3])
2787        assert_(not a.ravel(order="K").flags.owndata)
2788        assert_equal(a.ravel(), a.reshape(-1))
2789
2790        a = np.array([[0, 1], [2, 3]], order="F")
2791        assert_equal(a.ravel(), [0, 1, 2, 3])
2792        assert_equal(a.ravel(order="A"), [0, 2, 1, 3])
2793        assert_equal(a.ravel(order="K"), [0, 2, 1, 3])
2794        assert_(not a.ravel(order="A").flags.owndata)
2795        assert_(not a.ravel(order="K").flags.owndata)
2796        assert_equal(a.ravel(), a.reshape(-1))
2797        assert_equal(a.ravel(order="A"), a.reshape(-1, order="A"))
2798
2799        a = np.array([[0, 1], [2, 3]])[::-1, :]
2800        assert_equal(a.ravel(), [2, 3, 0, 1])
2801        assert_equal(a.ravel(order="C"), [2, 3, 0, 1])
2802        assert_equal(a.ravel(order="F"), [2, 0, 3, 1])
2803        assert_equal(a.ravel(order="A"), [2, 3, 0, 1])
2804        # 'K' doesn't reverse the axes of negative strides
2805        assert_equal(a.ravel(order="K"), [2, 3, 0, 1])
2806        assert_(a.ravel(order="K").flags.owndata)
2807
2808        # Test simple 1-d copy behaviour:
2809        a = np.arange(10)[::2]
2810        assert_(a.ravel("K").flags.owndata)
2811        assert_(a.ravel("C").flags.owndata)
2812        assert_(a.ravel("F").flags.owndata)
2813
2814        # Not contiguous and 1-sized axis with non matching stride
2815        a = np.arange(2**3 * 2)[::2]
2816        a = a.reshape(2, 1, 2, 2).swapaxes(-1, -2)
2817        strides = list(a.strides)
2818        strides[1] = 123
2819        a.strides = strides
2820        assert_(a.ravel(order="K").flags.owndata)
2821        assert_equal(a.ravel("K"), np.arange(0, 15, 2))
2822
2823        # contiguous and 1-sized axis with non matching stride works:
2824        a = np.arange(2**3)
2825        a = a.reshape(2, 1, 2, 2).swapaxes(-1, -2)
2826        strides = list(a.strides)
2827        strides[1] = 123
2828        a.strides = strides
2829        assert_(np.may_share_memory(a.ravel(order="K"), a))
2830        assert_equal(a.ravel(order="K"), np.arange(2**3))
2831
2832        # Test negative strides (not very interesting since non-contiguous):
2833        a = np.arange(4)[::-1].reshape(2, 2)
2834        assert_(a.ravel(order="C").flags.owndata)
2835        assert_(a.ravel(order="K").flags.owndata)
2836        assert_equal(a.ravel("C"), [3, 2, 1, 0])
2837        assert_equal(a.ravel("K"), [3, 2, 1, 0])
2838
2839        # 1-element tidy strides test:
2840        a = np.array([[1]])
2841        a.strides = (123, 432)
2842        # If the following stride is not 8, NPY_RELAXED_STRIDES_DEBUG is
2843        # messing them up on purpose:
2844        if np.ones(1).strides == (8,):
2845            assert_(np.may_share_memory(a.ravel("K"), a))
2846            assert_equal(a.ravel("K").strides, (a.dtype.itemsize,))
2847
2848        for order in ("C", "F", "A", "K"):
2849            # 0-d corner case:
2850            a = np.array(0)
2851            assert_equal(a.ravel(order), [0])
2852            assert_(np.may_share_memory(a.ravel(order), a))
2853
2854        # Test that certain non-inplace ravels work right (mostly) for 'K':
2855        b = np.arange(2**4 * 2)[::2].reshape(2, 2, 2, 2)
2856        a = b[..., ::2]
2857        assert_equal(a.ravel("K"), [0, 4, 8, 12, 16, 20, 24, 28])
2858        assert_equal(a.ravel("C"), [0, 4, 8, 12, 16, 20, 24, 28])
2859        assert_equal(a.ravel("A"), [0, 4, 8, 12, 16, 20, 24, 28])
2860        assert_equal(a.ravel("F"), [0, 16, 8, 24, 4, 20, 12, 28])
2861
2862        a = b[::2, ...]
2863        assert_equal(a.ravel("K"), [0, 2, 4, 6, 8, 10, 12, 14])
2864        assert_equal(a.ravel("C"), [0, 2, 4, 6, 8, 10, 12, 14])
2865        assert_equal(a.ravel("A"), [0, 2, 4, 6, 8, 10, 12, 14])
2866        assert_equal(a.ravel("F"), [0, 8, 4, 12, 2, 10, 6, 14])
2867
2868    @xfailIfTorchDynamo  # flags["OWNDATA"]
2869    def test_swapaxes(self):
2870        a = np.arange(1 * 2 * 3 * 4).reshape(1, 2, 3, 4).copy()
2871        idx = np.indices(a.shape)
2872        assert_(a.flags["OWNDATA"])
2873        b = a.copy()
2874        # check exceptions
2875        assert_raises(np.AxisError, a.swapaxes, -5, 0)
2876        assert_raises(np.AxisError, a.swapaxes, 4, 0)
2877        assert_raises(np.AxisError, a.swapaxes, 0, -5)
2878        assert_raises(np.AxisError, a.swapaxes, 0, 4)
2879
2880        for i in range(-4, 4):
2881            for j in range(-4, 4):
2882                for k, src in enumerate((a, b)):
2883                    c = src.swapaxes(i, j)
2884                    # check shape
2885                    shape = list(src.shape)
2886                    shape[i] = src.shape[j]
2887                    shape[j] = src.shape[i]
2888                    assert_equal(c.shape, shape, str((i, j, k)))
2889                    # check array contents
2890                    i0, i1, i2, i3 = (dim - 1 for dim in c.shape)
2891                    j0, j1, j2, j3 = (dim - 1 for dim in src.shape)
2892                    assert_equal(
2893                        src[idx[j0], idx[j1], idx[j2], idx[j3]],
2894                        c[idx[i0], idx[i1], idx[i2], idx[i3]],
2895                        str((i, j, k)),
2896                    )
2897                    # check a view is always returned, gh-5260
2898                    assert_(not c.flags["OWNDATA"], str((i, j, k)))
2899                    # check on non-contiguous input array
2900                    if k == 1:
2901                        b = c
2902
2903    def test_conjugate(self):
2904        a = np.array([1 - 1j, 1 + 1j, 23 + 23.0j])
2905        ac = a.conj()
2906        assert_equal(a.real, ac.real)
2907        assert_equal(a.imag, -ac.imag)
2908        assert_equal(ac, a.conjugate())
2909        assert_equal(ac, np.conjugate(a))
2910
2911        a = np.array([1 - 1j, 1 + 1j, 23 + 23.0j], "F")
2912        ac = a.conj()
2913        assert_equal(a.real, ac.real)
2914        assert_equal(a.imag, -ac.imag)
2915        assert_equal(ac, a.conjugate())
2916        assert_equal(ac, np.conjugate(a))
2917
2918        a = np.array([1, 2, 3])
2919        ac = a.conj()
2920        assert_equal(a, ac)
2921        assert_equal(ac, a.conjugate())
2922        assert_equal(ac, np.conjugate(a))
2923
2924        a = np.array([1.0, 2.0, 3.0])
2925        ac = a.conj()
2926        assert_equal(a, ac)
2927        assert_equal(ac, a.conjugate())
2928        assert_equal(ac, np.conjugate(a))
2929
2930    def test_conjugate_out(self):
2931        # Minimal test for the out argument being passed on correctly
2932        # NOTE: The ability to pass `out` is currently undocumented!
2933        a = np.array([1 - 1j, 1 + 1j, 23 + 23.0j])
2934        out = np.empty_like(a)
2935        res = a.conjugate(out)
2936        assert res is out
2937        assert_array_equal(out, a.conjugate())
2938
2939    def test__complex__(self):
2940        dtypes = [
2941            "i1",
2942            "i2",
2943            "i4",
2944            "i8",
2945            "u1",
2946            "f",
2947            "d",
2948            "F",
2949            "D",
2950            "?",
2951        ]
2952        for dt in dtypes:
2953            a = np.array(7, dtype=dt)
2954            b = np.array([7], dtype=dt)
2955            c = np.array([[[[[7]]]]], dtype=dt)
2956
2957            msg = f"dtype: {dt}"
2958            ap = complex(a)
2959            assert_equal(ap, a, msg)
2960            bp = complex(b)
2961            assert_equal(bp, b, msg)
2962            cp = complex(c)
2963            assert_equal(cp, c, msg)
2964
2965    def test__complex__should_not_work(self):
2966        dtypes = [
2967            "i1",
2968            "i2",
2969            "i4",
2970            "i8",
2971            "u1",
2972            "f",
2973            "d",
2974            "F",
2975            "D",
2976            "?",
2977        ]
2978        for dt in dtypes:
2979            a = np.array([1, 2, 3], dtype=dt)
2980            assert_raises((TypeError, ValueError), complex, a)
2981
2982        c = np.array([(1.0, 3), (2e-3, 7)], dtype=dt)
2983        assert_raises((TypeError, ValueError), complex, c)
2984
2985
2986class TestCequenceMethods(TestCase):
2987    def test_array_contains(self):
2988        assert_(4.0 in np.arange(16.0).reshape(4, 4))
2989        assert_(20.0 not in np.arange(16.0).reshape(4, 4))
2990
2991
2992class TestBinop(TestCase):
2993    def test_inplace(self):
2994        # test refcount 1 inplace conversion
2995        assert_array_almost_equal(np.array([0.5]) * np.array([1.0, 2.0]), [0.5, 1.0])
2996
2997        d = np.array([0.5, 0.5])[::2]
2998        assert_array_almost_equal(d * (d * np.array([1.0, 2.0])), [0.25, 0.5])
2999
3000        a = np.array([0.5])
3001        b = np.array([0.5])
3002        c = a + b
3003        c = a - b
3004        c = a * b
3005        c = a / b
3006        assert_equal(a, b)
3007        assert_almost_equal(c, 1.0)
3008
3009        c = a + b * 2.0 / b * a - a / b
3010        assert_equal(a, b)
3011        assert_equal(c, 0.5)
3012
3013        # true divide
3014        a = np.array([5])
3015        b = np.array([3])
3016        c = (a * a) / b
3017
3018        assert_almost_equal(c, 25 / 3, decimal=5)
3019        assert_equal(a, 5)
3020        assert_equal(b, 3)
3021
3022
3023class TestSubscripting(TestCase):
3024    def test_test_zero_rank(self):
3025        x = np.array([1, 2, 3])
3026        assert_(isinstance(x[0], (np.int_, np.ndarray)))
3027        assert_(type(x[0, ...]) is np.ndarray)
3028
3029
3030class TestFancyIndexing(TestCase):
3031    def test_list(self):
3032        x = np.ones((1, 1))
3033        x[:, [0]] = 2.0
3034        assert_array_equal(x, np.array([[2.0]]))
3035
3036        x = np.ones((1, 1, 1))
3037        x[:, :, [0]] = 2.0
3038        assert_array_equal(x, np.array([[[2.0]]]))
3039
3040    def test_tuple(self):
3041        x = np.ones((1, 1))
3042        x[:, (0,)] = 2.0
3043        assert_array_equal(x, np.array([[2.0]]))
3044        x = np.ones((1, 1, 1))
3045        x[:, :, (0,)] = 2.0
3046        assert_array_equal(x, np.array([[[2.0]]]))
3047
3048    def test_mask(self):
3049        x = np.array([1, 2, 3, 4])
3050        m = np.array([0, 1, 0, 0], bool)
3051        assert_array_equal(x[m], np.array([2]))
3052
3053    def test_mask2(self):
3054        x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
3055        m = np.array([0, 1], bool)
3056        m2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0]], bool)
3057        m3 = np.array([[0, 1, 0, 0], [0, 0, 0, 0]], bool)
3058        assert_array_equal(x[m], np.array([[5, 6, 7, 8]]))
3059        assert_array_equal(x[m2], np.array([2, 5]))
3060        assert_array_equal(x[m3], np.array([2]))
3061
3062    def test_assign_mask(self):
3063        x = np.array([1, 2, 3, 4])
3064        m = np.array([0, 1, 0, 0], bool)
3065        x[m] = 5
3066        assert_array_equal(x, np.array([1, 5, 3, 4]))
3067
3068    def test_assign_mask2(self):
3069        xorig = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
3070        m = np.array([0, 1], bool)
3071        m2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0]], bool)
3072        m3 = np.array([[0, 1, 0, 0], [0, 0, 0, 0]], bool)
3073        x = xorig.copy()
3074        x[m] = 10
3075        assert_array_equal(x, np.array([[1, 2, 3, 4], [10, 10, 10, 10]]))
3076        x = xorig.copy()
3077        x[m2] = 10
3078        assert_array_equal(x, np.array([[1, 10, 3, 4], [10, 6, 7, 8]]))
3079        x = xorig.copy()
3080        x[m3] = 10
3081        assert_array_equal(x, np.array([[1, 10, 3, 4], [5, 6, 7, 8]]))
3082
3083
3084@instantiate_parametrized_tests
3085class TestArgmaxArgminCommon(TestCase):
3086    sizes = [
3087        (),
3088        (3,),
3089        (3, 2),
3090        (2, 3),
3091        (3, 3),
3092        (2, 3, 4),
3093        (4, 3, 2),
3094        (1, 2, 3, 4),
3095        (2, 3, 4, 1),
3096        (3, 4, 1, 2),
3097        (4, 1, 2, 3),
3098        (64,),
3099        (128,),
3100        (256,),
3101    ]
3102
3103    @parametrize(
3104        "size, axis",
3105        list(
3106            itertools.chain(
3107                *[
3108                    [
3109                        (size, axis)
3110                        for axis in list(range(-len(size), len(size))) + [None]
3111                    ]
3112                    for size in sizes
3113                ]
3114            )
3115        ),
3116    )
3117    @skipif(numpy.__version__ < "1.23", reason="keepdims is new in numpy 1.22")
3118    @parametrize("method", [np.argmax, np.argmin])
3119    def test_np_argmin_argmax_keepdims(self, size, axis, method):
3120        arr = np.random.normal(size=size)
3121        if size is None or size == ():
3122            arr = np.asarray(arr)
3123
3124        # contiguous arrays
3125        if axis is None:
3126            new_shape = [1 for _ in range(len(size))]
3127        else:
3128            new_shape = list(size)
3129            new_shape[axis] = 1
3130        new_shape = tuple(new_shape)
3131
3132        _res_orig = method(arr, axis=axis)
3133        res_orig = _res_orig.reshape(new_shape)
3134        res = method(arr, axis=axis, keepdims=True)
3135        assert_equal(res, res_orig)
3136        assert_(res.shape == new_shape)
3137        outarray = np.empty(res.shape, dtype=res.dtype)
3138        res1 = method(arr, axis=axis, out=outarray, keepdims=True)
3139        assert_(res1 is outarray)
3140        assert_equal(res, outarray)
3141
3142        if len(size) > 0:
3143            wrong_shape = list(new_shape)
3144            if axis is not None:
3145                wrong_shape[axis] = 2
3146            else:
3147                wrong_shape[0] = 2
3148            wrong_outarray = np.empty(wrong_shape, dtype=res.dtype)
3149            with pytest.raises(ValueError):
3150                method(arr.T, axis=axis, out=wrong_outarray, keepdims=True)
3151
3152        # non-contiguous arrays
3153        if axis is None:
3154            new_shape = [1 for _ in range(len(size))]
3155        else:
3156            new_shape = list(size)[::-1]
3157            new_shape[axis] = 1
3158        new_shape = tuple(new_shape)
3159
3160        _res_orig = method(arr.T, axis=axis)
3161        res_orig = _res_orig.reshape(new_shape)
3162        res = method(arr.T, axis=axis, keepdims=True)
3163        assert_equal(res, res_orig)
3164        assert_(res.shape == new_shape)
3165        outarray = np.empty(new_shape[::-1], dtype=res.dtype)
3166        outarray = outarray.T
3167        res1 = method(arr.T, axis=axis, out=outarray, keepdims=True)
3168        assert_(res1 is outarray)
3169        assert_equal(res, outarray)
3170
3171        if len(size) > 0:
3172            # one dimension lesser for non-zero sized
3173            # array should raise an error
3174            with pytest.raises(ValueError):
3175                method(arr[0], axis=axis, out=outarray, keepdims=True)
3176
3177        if len(size) > 0:
3178            wrong_shape = list(new_shape)
3179            if axis is not None:
3180                wrong_shape[axis] = 2
3181            else:
3182                wrong_shape[0] = 2
3183            wrong_outarray = np.empty(wrong_shape, dtype=res.dtype)
3184            with pytest.raises(ValueError):
3185                method(arr.T, axis=axis, out=wrong_outarray, keepdims=True)
3186
3187    @xpassIfTorchDynamo  # (reason="TODO: implement choose")
3188    @parametrize("method", ["max", "min"])
3189    def test_all(self, method):
3190        a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
3191        arg_method = getattr(a, "arg" + method)
3192        val_method = getattr(a, method)
3193        for i in range(a.ndim):
3194            a_maxmin = val_method(i)
3195            aarg_maxmin = arg_method(i)
3196            axes = list(range(a.ndim))
3197            axes.remove(i)
3198            assert_(np.all(a_maxmin == aarg_maxmin.choose(*a.transpose(i, *axes))))
3199
3200    @parametrize("method", ["argmax", "argmin"])
3201    def test_output_shape(self, method):
3202        # see also gh-616
3203        a = np.ones((10, 5))
3204        arg_method = getattr(a, method)
3205        # Check some simple shape mismatches
3206        out = np.ones(11, dtype=np.int_)
3207        assert_raises(ValueError, arg_method, -1, out)
3208
3209        out = np.ones((2, 5), dtype=np.int_)
3210        assert_raises(ValueError, arg_method, -1, out)
3211
3212        # these could be relaxed possibly (used to allow even the previous)
3213        out = np.ones((1, 10), dtype=np.int_)
3214        assert_raises(ValueError, arg_method, -1, out)
3215
3216        out = np.ones(10, dtype=np.int_)
3217        arg_method(-1, out=out)
3218        assert_equal(out, arg_method(-1))
3219
3220    @parametrize("ndim", [0, 1])
3221    @parametrize("method", ["argmax", "argmin"])
3222    def test_ret_is_out(self, ndim, method):
3223        a = np.ones((4,) + (256,) * ndim)
3224        arg_method = getattr(a, method)
3225        out = np.empty((256,) * ndim, dtype=np.intp)
3226        ret = arg_method(axis=0, out=out)
3227        assert ret is out
3228
3229    @parametrize(
3230        "arr_method, np_method", [("argmax", np.argmax), ("argmin", np.argmin)]
3231    )
3232    def test_np_vs_ndarray(self, arr_method, np_method):
3233        # make sure both ndarray.argmax/argmin and
3234        # numpy.argmax/argmin support out/axis args
3235        a = np.random.normal(size=(2, 3))
3236        arg_method = getattr(a, arr_method)
3237
3238        # check positional args
3239        out1 = np.zeros(2, dtype=int)
3240        out2 = np.zeros(2, dtype=int)
3241        assert_equal(arg_method(1, out1), np_method(a, 1, out2))
3242        assert_equal(out1, out2)
3243
3244        # check keyword args
3245        out1 = np.zeros(3, dtype=int)
3246        out2 = np.zeros(3, dtype=int)
3247        assert_equal(arg_method(out=out1, axis=0), np_method(a, out=out2, axis=0))
3248        assert_equal(out1, out2)
3249
3250
@instantiate_parametrized_tests
class TestArgmax(TestCase):
    # (values, index of the first maximum); all values are non-negative so
    # they are valid for unsigned dtypes.
    usg_data = [
        ([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 0),
        ([3, 3, 3, 3, 2, 2, 2, 2], 0),
        ([0, 1, 2, 3, 4, 5, 6, 7], 7),
        ([7, 6, 5, 4, 3, 2, 1, 0], 0),
    ]
    # Signed dtypes additionally get cases that contain negative values.
    sg_data = usg_data + [
        ([1, 2, 3, 4, -4, -3, -2, -1], 3),
        ([1, 2, 3, 4, -1, -2, -3, -4], 3),
    ]
    darr = [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(usg_data, (np.uint8,))
    ]
    darr += [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(
            sg_data,
            (np.int8, np.int16, np.int32, np.int64, np.float32, np.float64),
        )
    ]
    # Floating-point cases containing NaN: the expected index is always the
    # position of the first NaN.
    darr += [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(
            (
                ([0, 1, 2, 3, np.nan], 4),
                ([0, 1, 2, np.nan, 3], 3),
                ([np.nan, 0, 1, 2, 3], 0),
                ([np.nan, 0, np.nan, 2, 3], 0),
            )
            # A NaN right after (width * 5 - 1) ones hits the tails of the
            # multi-level (x4, x1) SIMD inner loops for several SIMD widths.
            + tuple(
                ([1] * (width * 5 - 1) + [np.nan], width * 5 - 1)
                for width in (2, 4, 8, 16, 32)
            ),
            (np.float32, np.float64),
        )
    ]
    nan_arr = darr + [
        # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble'
        #      ([0, 1, 2, 3, complex(0, np.nan)], 4),
        #      ([0, 1, 2, 3, complex(np.nan, 0)], 4),
        #      ([0, 1, 2, complex(np.nan, 0), 3], 3),
        #      ([0, 1, 2, complex(0, np.nan), 3], 3),
        #      ([complex(0, np.nan), 0, 1, 2, 3], 0),
        #      ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
        #      ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
        #      ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
        #      ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
        #      ([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
        #      ([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
        #      ([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
        ([False, False, False, False, True], 4),
        ([False, False, False, True, False], 3),
        ([True, False, False, False, False], 0),
        ([True, False, True, False, False], 0),
    ]

    @parametrize("data", nan_arr)
    def test_combinations(self, data):
        arr, pos = data
        # Inputs containing NaN may warn during the reduction.
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning, "invalid value encountered in reduce")
            val = np.max(arr)

        def check(candidate, expected_pos):
            msg = f"{candidate!r}"
            assert_equal(np.argmax(candidate), expected_pos, err_msg=msg)
            assert_equal(candidate[np.argmax(candidate)], val, err_msg=msg)

        check(arr, pos)
        # Longer inputs push the data through the SIMD loops.  Repeating each
        # element 129 times scales the index of the first maximum by 129 ...
        check(np.repeat(arr, 129), pos * 129)
        # ... while appending 513 copies of the minimum leaves it unchanged.
        check(np.concatenate((arr, np.repeat(np.min(arr), 513))), pos)

    def test_maximum_signed_integers(self):
        # The largest representable value must win over the smallest one.
        int_types = ((np.int8, 8), (np.int16, 16), (np.int32, 32), (np.int64, 64))
        for dtype, bits in int_types:
            top = 2 ** (bits - 1)
            a = np.array([1, top - 1, -top], dtype=dtype)
            assert_equal(np.argmax(a), 1)
            assert_equal(np.argmax(a.repeat(129)), 129)
3358
3359
@instantiate_parametrized_tests
class TestArgmin(TestCase):
    # (values, index of the first minimum); all values are non-negative so
    # they are valid for unsigned dtypes.
    usg_data = [
        ([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 8),
        ([3, 3, 3, 3, 2, 2, 2, 2], 4),
        ([0, 1, 2, 3, 4, 5, 6, 7], 0),
        ([7, 6, 5, 4, 3, 2, 1, 0], 7),
    ]
    # Signed dtypes additionally get cases that contain negative values.
    sg_data = usg_data + [
        ([1, 2, 3, 4, -4, -3, -2, -1], 4),
        ([1, 2, 3, 4, -1, -2, -3, -4], 7),
    ]
    darr = [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(usg_data, (np.uint8,))
    ]
    darr += [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(
            sg_data,
            (np.int8, np.int16, np.int32, np.int64, np.float32, np.float64),
        )
    ]
    # Floating-point cases containing NaN: the expected index is always the
    # position of the first NaN.
    darr += [
        (np.array(values, dtype=dtype), expected)
        for (values, expected), dtype in itertools.product(
            (
                ([0, 1, 2, 3, np.nan], 4),
                ([0, 1, 2, np.nan, 3], 3),
                ([np.nan, 0, 1, 2, 3], 0),
                ([np.nan, 0, np.nan, 2, 3], 0),
            )
            # A NaN right after (width * 5 - 1) ones hits the tails of the
            # multi-level (x4, x1) SIMD inner loops for several SIMD widths.
            + tuple(
                ([1] * (width * 5 - 1) + [np.nan], width * 5 - 1)
                for width in (2, 4, 8, 16, 32)
            ),
            (np.float32, np.float64),
        )
    ]
    nan_arr = darr + [
        # RuntimeError: "min_values_cpu" not implemented for 'ComplexDouble'
        #    ([0, 1, 2, 3, complex(0, np.nan)], 4),
        #    ([0, 1, 2, 3, complex(np.nan, 0)], 4),
        #    ([0, 1, 2, complex(np.nan, 0), 3], 3),
        #    ([0, 1, 2, complex(0, np.nan), 3], 3),
        #    ([complex(0, np.nan), 0, 1, 2, 3], 0),
        #    ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
        #    ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
        #    ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
        #    ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
        #    ([complex(0, 0), complex(0, 2), complex(0, 1)], 0),
        #    ([complex(1, 0), complex(0, 2), complex(0, 1)], 2),
        #    ([complex(1, 0), complex(0, 2), complex(1, 1)], 1),
        ([True, True, True, True, False], 4),
        ([True, True, True, False, True], 3),
        ([False, True, True, True, True], 0),
        ([False, True, False, True, True], 0),
    ]

    @parametrize("data", nan_arr)
    def test_combinations(self, data):
        arr, pos = data
        # Inputs containing NaN may warn during the reduction.
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning, "invalid value encountered in reduce")
            min_val = np.min(arr)

        def check(candidate, expected_pos):
            msg = f"{candidate!r}"
            assert_equal(np.argmin(candidate), expected_pos, err_msg=msg)
            assert_equal(candidate[np.argmin(candidate)], min_val, err_msg=msg)

        check(arr, pos)
        # Longer inputs push the data through the SIMD loops.  Repeating each
        # element 129 times scales the index of the first minimum by 129 ...
        check(np.repeat(arr, 129), pos * 129)
        # ... while appending 513 copies of the maximum leaves it unchanged.
        check(np.concatenate((arr, np.repeat(np.max(arr), 513))), pos)

    def test_minimum_signed_integers(self):
        # The smallest representable value must win, even when its neighbour
        # and the type's maximum are also present.
        int_types = ((np.int8, 8), (np.int16, 16), (np.int32, 32), (np.int64, 64))
        for dtype, bits in int_types:
            top = 2 ** (bits - 1)
            a = np.array([1, -top, -top + 1, top - 1], dtype=dtype)
            assert_equal(np.argmin(a), 1)
            assert_equal(np.argmin(a.repeat(129)), 129)
3467
3468
class TestMinMax(TestCase):
    @xpassIfTorchDynamo
    def test_scalar(self):
        reducers = (np.amax, np.amin)
        # A scalar only accepts axis=0 or axis=None; axis 1 is out of range.
        for reduce in reducers:
            assert_raises(np.AxisError, reduce, 1, 1)
        for axis in (0, None):
            for reduce in reducers:
                assert_equal(reduce(1, axis=axis), 1)

    def test_axis(self):
        assert_raises(np.AxisError, np.amax, [1, 2, 3], 1000)
        single_row = [[1, 2, 3]]
        assert_equal(np.amax(single_row, axis=1), 3)
3483
3484
class TestNewaxis(TestCase):
    def test_basic(self):
        vec = np.array([0, -0.1, 0.1])
        # Inserting a new trailing axis turns the length-3 vector into a
        # (3, 1) column; scaling and flattening must match scaling the vector.
        column = vec[:, np.newaxis]
        assert_almost_equal((250 * column).ravel(), 250 * vec)
3490
3491
3492_sctypes = {
3493    "int": [np.int8, np.int16, np.int32, np.int64],
3494    "uint": [np.uint8, np.uint16, np.uint32, np.uint64],
3495    "float": [np.float32, np.float64],
3496    "complex": [np.complex64, np.complex128]
3497    # no complex256 in torch._numpy
3498    + ([np.clongdouble] if hasattr(np, "clongdouble") else []),
3499}
3500
3501
class TestClip(TestCase):
    def _check_range(self, x, cmin, cmax):
        # Every element must lie in the closed interval [cmin, cmax].
        assert_(np.all(x >= cmin))
        assert_(np.all(x <= cmax))

    def _clip_type(
        self,
        type_group,
        array_max,
        clip_min,
        clip_max,
        inplace=False,
        expected_min=None,
        expected_max=None,
    ):
        """Clip random data for every dtype in ``_sctypes[type_group]``.

        Each dtype is tried in native and in swapped byte order.  The result
        must carry the expected byte order and lie within
        [expected_min, expected_max], which default to the clip bounds.
        Returns the array produced by the last iteration.
        """
        lo = clip_min if expected_min is None else expected_min
        hi = clip_max if expected_max is None else expected_max
        # One native ("=") and one non-native order for this host.
        orders = ["=", ">"] if sys.byteorder == "little" else ["<", "="]

        for T in _sctypes[type_group]:
            for order in orders:
                dtype = np.dtype(T).newbyteorder(order)
                x = (np.random.random(1000) * array_max).astype(dtype)
                if not inplace:
                    # A freshly allocated clip() result is expected to be
                    # in native byte order.
                    x = x.clip(clip_min, clip_max)
                    expected_order = "="
                else:
                    # The tests that call us pass clip_min and clip_max that
                    # might not fit in the destination dtype. They were written
                    # assuming the previous unsafe casting, which now must be
                    # passed explicitly to avoid a warning.
                    x.clip(clip_min, clip_max, x, casting="unsafe")
                    expected_order = order

                # numpy reports "|" when byte order is not applicable.
                if x.dtype.byteorder == "|":
                    expected_order = "|"
                assert_equal(x.dtype.byteorder, expected_order)
                self._check_range(x, lo, hi)
        return x

    @skip(reason="endianness")
    def test_basic(self):
        # (type group, clip_min, clip_max, extra keyword arguments)
        cases = [
            ("float", -12.8, 100.2, {}),
            ("float", 0, 0, {}),
            ("int", -120, 100, {}),
            ("int", 0, 0, {}),
            ("uint", 0, 0, {}),
            ("uint", -120, 100, {"expected_min": 0}),
        ]
        for inplace in [False, True]:
            for group, cmin, cmax, extra in cases:
                self._clip_type(group, 1024, cmin, cmax, inplace=inplace, **extra)

    def test_max_or_min(self):
        val = np.array([0, 1, 2, 3, 4, 5, 6, 7])
        # A single positional bound is the lower bound.
        assert_(np.all(val.clip(3) >= 3))
        assert_(np.all(val.clip(min=3) >= 3))
        assert_(np.all(val.clip(max=4) <= 4))

    def test_nan(self):
        # NaNs pass through unchanged; finite values are clamped to [-1, 1].
        input_arr = np.array([-2.0, np.nan, 0.5, 3.0, 0.25, np.nan])
        expected = np.array([-1.0, np.nan, 0.5, 1.0, 0.25, np.nan])
        assert_array_equal(input_arr.clip(-1, 1), expected)
3574
3575
@xpassIfTorchDynamo  # (reason="TODO")
class TestCompress(TestCase):
    def test_axis(self):
        arr = np.arange(10).reshape(2, 5)
        # Keep only the second row ...
        assert_equal(np.compress([0, 1], arr, axis=0), [[5, 6, 7, 8, 9]])
        # ... or only columns 1 and 3.
        assert_equal(np.compress([0, 1, 0, 1, 0], arr, axis=1), [[1, 3], [6, 8]])

    def test_truncate(self):
        # A condition shorter than the axis only selects among the leading
        # entries of that axis.
        arr = np.arange(10).reshape(2, 5)
        assert_equal(np.compress([0, 1], arr, axis=1), [[1], [6]])

    def test_flatten(self):
        # Without an axis the input is treated as flattened.
        arr = np.arange(10).reshape(2, 5)
        assert_equal(np.compress([0, 1], arr), 1)
3598
3599
@xpassIfTorchDynamo  # (reason="TODO")
@instantiate_parametrized_tests
class TestPutmask(TestCase):
    def tst_basic(self, x, T, mask, val):
        """Put ``val`` into ``x`` wherever ``mask`` is set and check the result."""
        np.putmask(x, mask, val)
        assert_equal(x[mask], np.array(val, T))

    def test_ip_types(self):
        # Kept from upstream, where ``T`` ranged over scalar *types*.  The
        # single-character typecodes iterated below never compare equal to
        # these type objects, so nothing is actually skipped.
        unchecked_types = [bytes, str, np.void]

        x = np.random.random(1000) * 100
        mask = x < 40

        for val in [-100, 0, 15]:
            for T in "efdFDBbhil?":
                if T in unchecked_types:
                    continue
                # A negative value does not fit an unsigned type, so use a
                # large in-range substitute for that type only.  ``val`` must
                # not be overwritten: doing so would make every later (signed)
                # type, and the string check below, use the substitute too.
                if val < 0 and np.dtype(T).kind == "u":
                    type_val = np.iinfo(T).max - 99
                else:
                    type_val = val
                self.tst_basic(x.copy().astype(T), T, mask, type_val)

            # Also test string of a length which uses an untypical length
            dt = np.dtype("S3")
            self.tst_basic(x.astype(dt), dt.type, mask, dt.type(val)[:3])

    def test_mask_size(self):
        # A mask whose size differs from the array's is rejected.
        assert_raises(ValueError, np.putmask, np.array([1, 2, 3]), [True], 5)

    @parametrize("greater", (True, False))
    def test_byteorder(self, greater):
        dtype = ">i4" if greater else "<i4"
        x = np.array([1, 2, 3], dtype)
        np.putmask(x, [True, False, True], -1)
        assert_array_equal(x, [-1, 2, -1])

    def test_record_array(self):
        # Note mixed byteorder.
        rec = np.array(
            [(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
            dtype=[("x", "<f8"), ("y", ">f8"), ("z", "<f8")],
        )
        np.putmask(rec["x"], [True, False], 10)
        assert_array_equal(rec["x"], [10, 5])
        assert_array_equal(rec["y"], [2, 4])
        assert_array_equal(rec["z"], [3, 3])
        np.putmask(rec["y"], [True, False], 11)
        assert_array_equal(rec["x"], [10, 5])
        assert_array_equal(rec["y"], [11, 4])
        assert_array_equal(rec["z"], [3, 3])

    def test_overlaps(self):
        # gh-6272 check overlap
        x = np.array([True, False, True, False])
        np.putmask(x[1:4], [True, True, True], x[:3])
        assert_equal(x, np.array([True, True, False, True]))

        x = np.array([True, False, True, False])
        np.putmask(x[1:4], x[:3], [True, False, True])
        assert_equal(x, np.array([True, True, True, True]))

    def test_writeable(self):
        a = np.arange(5)
        a.flags.writeable = False

        with pytest.raises(ValueError):
            np.putmask(a, a >= 2, 3)

    def test_kwargs(self):
        # mask/values may be passed positionally or by keyword, in any order.
        x = np.array([0, 0])
        np.putmask(x, [0, 1], [-1, -2])
        assert_array_equal(x, [0, -2])

        x = np.array([0, 0])
        np.putmask(x, mask=[0, 1], values=[-1, -2])
        assert_array_equal(x, [0, -2])

        x = np.array([0, 0])
        np.putmask(x, values=[-1, -2], mask=[0, 1])
        assert_array_equal(x, [0, -2])

        # ...but the target array itself cannot be passed as ``a=``.
        with pytest.raises(TypeError):
            np.putmask(a=x, values=[-1, -2], mask=[0, 1])
3682
3683
@instantiate_parametrized_tests
class TestTake(TestCase):
    def tst_basic(self, x):
        # Taking every index along axis 0 must reproduce the input.
        every_index = list(range(x.shape[0]))
        assert_array_equal(np.take(x, every_index, axis=0), x)

    def test_ip_types(self):
        x = np.reshape(np.random.random(24) * 100, (2, 3, 4))
        for typecode in "efdFDBbhil?":
            self.tst_basic(x.copy().astype(typecode))

    def test_raise(self):
        x = np.reshape(np.random.random(24) * 100, (2, 3, 4))
        # Axis 0 has length 2, so index 2 and index -3 are out of bounds ...
        assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0)
        assert_raises(IndexError, np.take, x, [-3], axis=0)
        # ... while -1 counts from the end.
        assert_array_equal(np.take(x, [-1], axis=0)[0], x[1])

    @xpassIfTorchDynamo  # (reason="XXX: take(..., mode='clip')")
    def test_clip(self):
        x = np.reshape(np.random.random(24) * 100, (2, 3, 4))
        # Out-of-range indices are clamped to the first/last entry.
        assert_array_equal(np.take(x, [-1], axis=0, mode="clip")[0], x[0])
        assert_array_equal(np.take(x, [2], axis=0, mode="clip")[0], x[1])

    @xpassIfTorchDynamo  # (reason="XXX: take(..., mode='wrap')")
    def test_wrap(self):
        x = np.reshape(np.random.random(24) * 100, (2, 3, 4))
        # Indices are reduced modulo the axis length (2).
        for index, expected in (([-1], x[1]), ([2], x[0]), ([3], x[1])):
            assert_array_equal(np.take(x, index, axis=0, mode="wrap")[0], expected)

    @xpassIfTorchDynamo  # (reason="XXX: take(mode='wrap')")
    def test_out_overlap(self):
        # gh-6272: ``out`` is a view onto part of the input itself.
        x = np.arange(5)
        result = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap")
        assert_equal(result, np.array([1, 2, 3]))

    @parametrize("shape", [(1, 2), (1,), ()])
    def test_ret_is_out(self, shape):
        # 0d arrays should not be an exception to this rule
        x = np.arange(5)
        indices = np.zeros(shape, dtype=np.intp)
        out = np.zeros(shape, dtype=x.dtype)
        assert np.take(x, indices, out=out) is out
3734
3735
@xpassIfTorchDynamo  # (reason="TODO")
@instantiate_parametrized_tests
class TestLexsort(TestCase):
    @parametrize(
        "dtype",
        [
            np.uint8,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.float16,
            np.float32,
            np.float64,
        ],
    )
    def test_basic(self, dtype):
        primary = np.array([1, 2, 1, 3, 1, 5], dtype=dtype)
        secondary = np.array([0, 4, 5, 6, 2, 3], dtype=dtype)
        # lexsort uses the LAST key as the primary sort key.
        idx = np.lexsort((secondary, primary))
        assert_array_equal(idx, np.array([0, 4, 2, 1, 3, 5]))
        assert_array_equal(primary[idx], np.sort(primary))

    def test_mixed(self):
        primary = np.array([1, 2, 1, 3, 1, 5])
        secondary = np.array([0, 4, 5, 6, 2, 3], dtype="datetime64[D]")
        assert_array_equal(
            np.lexsort((secondary, primary)), np.array([0, 4, 2, 1, 3, 5])
        )

    def test_datetime(self):
        for unit in ("datetime64[D]", "timedelta64[D]"):
            a = np.array([0, 0, 0], dtype=unit)
            b = np.array([2, 1, 0], dtype=unit)
            # ``a`` is all ties, so the order is decided entirely by ``b``.
            assert_array_equal(np.lexsort((b, a)), np.array([2, 1, 0]))

    def test_object(self):  # gh-6312
        a = np.random.choice(10, 1000)
        b = np.random.choice(["abc", "xy", "wz", "efghi", "qwst", "x"], 1000)

        # A single object-dtype key must agree with a stable argsort.
        for u in (a, b):
            assert_array_equal(
                np.lexsort((u.astype("O"),)), np.argsort(u, kind="mergesort")
            )

        # Converting either key (or both) to object dtype must not change
        # the resulting order.
        for u, v in ((a, b), (b, a)):
            idx = np.lexsort((u, v))
            assert_array_equal(idx, np.lexsort((u.astype("O"), v)))
            assert_array_equal(idx, np.lexsort((u, v.astype("O"))))
            u_obj = np.array(u, dtype="object")
            v_obj = np.array(v, dtype="object")
            assert_array_equal(idx, np.lexsort((u_obj, v_obj)))

    def test_invalid_axis(self):  # gh-7528
        keys = np.linspace(0.0, 1.0, 42 * 3).reshape(42, 3)
        # axis=2 is out of range for these keys.
        assert_raises(np.AxisError, np.lexsort, keys, axis=2)
3800
3801
3802@skip(reason="dont worry about IO")
3803class TestIO(TestCase):
3804    """Test tofile, fromfile, tobytes, and fromstring"""
3805
3806    @pytest.fixture()
3807    def x(self):
3808        shape = (2, 4, 3)
3809        rand = np.random.random
3810        x = rand(shape) + rand(shape).astype(complex) * 1j
3811        x[0, :, 1] = [np.nan, np.inf, -np.inf, np.nan]
3812        return x
3813
3814    @pytest.fixture(params=["string", "path_obj"])
3815    def tmp_filename(self, tmp_path, request):
3816        # This fixture covers two cases:
3817        # one where the filename is a string and
3818        # another where it is a pathlib object
3819        filename = tmp_path / "file"
3820        if request.param == "string":
3821            filename = str(filename)
3822        return filename
3823
3824    def test_nofile(self):
3825        # this should probably be supported as a file
3826        # but for now test for proper errors
3827        b = io.BytesIO()
3828        assert_raises(OSError, np.fromfile, b, np.uint8, 80)
3829        d = np.ones(7)
3830        assert_raises(OSError, lambda x: x.tofile(b), d)
3831
3832    def test_bool_fromstring(self):
3833        v = np.array([True, False, True, False], dtype=np.bool_)
3834        y = np.fromstring("1 0 -2.3 0.0", sep=" ", dtype=np.bool_)
3835        assert_array_equal(v, y)
3836
3837    def test_uint64_fromstring(self):
3838        d = np.fromstring(
3839            "9923372036854775807 104783749223640", dtype=np.uint64, sep=" "
3840        )
3841        e = np.array([9923372036854775807, 104783749223640], dtype=np.uint64)
3842        assert_array_equal(d, e)
3843
3844    def test_int64_fromstring(self):
3845        d = np.fromstring("-25041670086757 104783749223640", dtype=np.int64, sep=" ")
3846        e = np.array([-25041670086757, 104783749223640], dtype=np.int64)
3847        assert_array_equal(d, e)
3848
3849    def test_fromstring_count0(self):
3850        d = np.fromstring("1,2", sep=",", dtype=np.int64, count=0)
3851        assert d.shape == (0,)
3852
3853    def test_empty_files_text(self, tmp_filename):
3854        with open(tmp_filename, "w") as f:
3855            pass
3856        y = np.fromfile(tmp_filename)
3857        assert_(y.size == 0, "Array not empty")
3858
3859    def test_empty_files_binary(self, tmp_filename):
3860        with open(tmp_filename, "wb") as f:
3861            pass
3862        y = np.fromfile(tmp_filename, sep=" ")
3863        assert_(y.size == 0, "Array not empty")
3864
3865    def test_roundtrip_file(self, x, tmp_filename):
3866        with open(tmp_filename, "wb") as f:
3867            x.tofile(f)
3868        # NB. doesn't work with flush+seek, due to use of C stdio
3869        with open(tmp_filename, "rb") as f:
3870            y = np.fromfile(f, dtype=x.dtype)
3871        assert_array_equal(y, x.flat)
3872
3873    def test_roundtrip(self, x, tmp_filename):
3874        x.tofile(tmp_filename)
3875        y = np.fromfile(tmp_filename, dtype=x.dtype)
3876        assert_array_equal(y, x.flat)
3877
3878    def test_roundtrip_dump_pathlib(self, x, tmp_filename):
3879        p = Path(tmp_filename)
3880        x.dump(p)
3881        y = np.load(p, allow_pickle=True)
3882        assert_array_equal(y, x)
3883
3884    def test_roundtrip_binary_str(self, x):
3885        s = x.tobytes()
3886        y = np.frombuffer(s, dtype=x.dtype)
3887        assert_array_equal(y, x.flat)
3888
3889        s = x.tobytes("F")
3890        y = np.frombuffer(s, dtype=x.dtype)
3891        assert_array_equal(y, x.flatten("F"))
3892
3893    def test_roundtrip_str(self, x):
3894        x = x.real.ravel()
3895        s = "@".join(map(str, x))
3896        y = np.fromstring(s, sep="@")
3897        # NB. str imbues less precision
3898        nan_mask = ~np.isfinite(x)
3899        assert_array_equal(x[nan_mask], y[nan_mask])
3900        assert_array_almost_equal(x[~nan_mask], y[~nan_mask], decimal=5)
3901
3902    def test_roundtrip_repr(self, x):
3903        x = x.real.ravel()
3904        s = "@".join(map(repr, x))
3905        y = np.fromstring(s, sep="@")
3906        assert_array_equal(x, y)
3907
3908    def test_unseekable_fromfile(self, x, tmp_filename):
3909        # gh-6246
3910        x.tofile(tmp_filename)
3911
3912        def fail(*args, **kwargs):
3913            raise OSError("Can not tell or seek")
3914
3915        with open(tmp_filename, "rb", buffering=0) as f:
3916            f.seek = fail
3917            f.tell = fail
3918            assert_raises(OSError, np.fromfile, f, dtype=x.dtype)
3919
3920    def test_io_open_unbuffered_fromfile(self, x, tmp_filename):
3921        # gh-6632
3922        x.tofile(tmp_filename)
3923        with open(tmp_filename, "rb", buffering=0) as f:
3924            y = np.fromfile(f, dtype=x.dtype)
3925            assert_array_equal(y, x.flat)
3926
3927    def test_largish_file(self, tmp_filename):
3928        # check the fallocate path on files > 16MB
3929        d = np.zeros(4 * 1024**2)
3930        d.tofile(tmp_filename)
3931        assert_equal(os.path.getsize(tmp_filename), d.nbytes)
3932        assert_array_equal(d, np.fromfile(tmp_filename))
3933        # check offset
3934        with open(tmp_filename, "r+b") as f:
3935            f.seek(d.nbytes)
3936            d.tofile(f)
3937            assert_equal(os.path.getsize(tmp_filename), d.nbytes * 2)
3938        # check append mode (gh-8329)
3939        open(tmp_filename, "w").close()  # delete file contents
3940        with open(tmp_filename, "ab") as f:
3941            d.tofile(f)
3942        assert_array_equal(d, np.fromfile(tmp_filename))
3943        with open(tmp_filename, "ab") as f:
3944            d.tofile(f)
3945        assert_equal(os.path.getsize(tmp_filename), d.nbytes * 2)
3946
3947    def test_io_open_buffered_fromfile(self, x, tmp_filename):
3948        # gh-6632
3949        x.tofile(tmp_filename)
3950        with open(tmp_filename, "rb", buffering=-1) as f:
3951            y = np.fromfile(f, dtype=x.dtype)
3952        assert_array_equal(y, x.flat)
3953
3954    def test_file_position_after_fromfile(self, tmp_filename):
3955        # gh-4118
3956        sizes = [
3957            io.DEFAULT_BUFFER_SIZE // 8,
3958            io.DEFAULT_BUFFER_SIZE,
3959            io.DEFAULT_BUFFER_SIZE * 8,
3960        ]
3961
3962        for size in sizes:
3963            with open(tmp_filename, "wb") as f:
3964                f.seek(size - 1)
3965                f.write(b"\0")
3966
3967            for mode in ["rb", "r+b"]:
3968                err_msg = "%d %s" % (size, mode)
3969
3970                with open(tmp_filename, mode) as f:
3971                    f.read(2)
3972                    np.fromfile(f, dtype=np.float64, count=1)
3973                    pos = f.tell()
3974                assert_equal(pos, 10, err_msg=err_msg)
3975
3976    def test_file_position_after_tofile(self, tmp_filename):
3977        # gh-4118
3978        sizes = [
3979            io.DEFAULT_BUFFER_SIZE // 8,
3980            io.DEFAULT_BUFFER_SIZE,
3981            io.DEFAULT_BUFFER_SIZE * 8,
3982        ]
3983
3984        for size in sizes:
3985            err_msg = "%d" % (size,)
3986
3987            with open(tmp_filename, "wb") as f:
3988                f.seek(size - 1)
3989                f.write(b"\0")
3990                f.seek(10)
3991                f.write(b"12")
3992                np.array([0], dtype=np.float64).tofile(f)
3993                pos = f.tell()
3994            assert_equal(pos, 10 + 2 + 8, err_msg=err_msg)
3995
3996            with open(tmp_filename, "r+b") as f:
3997                f.read(2)
3998                f.seek(0, 1)  # seek between read&write required by ANSI C
3999                np.array([0], dtype=np.float64).tofile(f)
4000                pos = f.tell()
4001            assert_equal(pos, 10, err_msg=err_msg)
4002
4003    def test_load_object_array_fromfile(self, tmp_filename):
4004        # gh-12300
4005        with open(tmp_filename, "w") as f:
4006            # Ensure we have a file with consistent contents
4007            pass
4008
4009        with open(tmp_filename, "rb") as f:
4010            assert_raises_regex(
4011                ValueError,
4012                "Cannot read into object array",
4013                np.fromfile,
4014                f,
4015                dtype=object,
4016            )
4017
4018        assert_raises_regex(
4019            ValueError,
4020            "Cannot read into object array",
4021            np.fromfile,
4022            tmp_filename,
4023            dtype=object,
4024        )
4025
4026    def test_fromfile_offset(self, x, tmp_filename):
4027        with open(tmp_filename, "wb") as f:
4028            x.tofile(f)
4029
4030        with open(tmp_filename, "rb") as f:
4031            y = np.fromfile(f, dtype=x.dtype, offset=0)
4032            assert_array_equal(y, x.flat)
4033
4034        with open(tmp_filename, "rb") as f:
4035            count_items = len(x.flat) // 8
4036            offset_items = len(x.flat) // 4
4037            offset_bytes = x.dtype.itemsize * offset_items
4038            y = np.fromfile(f, dtype=x.dtype, count=count_items, offset=offset_bytes)
4039            assert_array_equal(y, x.flat[offset_items : offset_items + count_items])
4040
4041            # subsequent seeks should stack
4042            offset_bytes = x.dtype.itemsize
4043            z = np.fromfile(f, dtype=x.dtype, offset=offset_bytes)
4044            assert_array_equal(z, x.flat[offset_items + count_items + 1 :])
4045
4046        with open(tmp_filename, "wb") as f:
4047            x.tofile(f, sep=",")
4048
4049        with open(tmp_filename, "rb") as f:
4050            assert_raises_regex(
4051                TypeError,
4052                "'offset' argument only permitted for binary files",
4053                np.fromfile,
4054                tmp_filename,
4055                dtype=x.dtype,
4056                sep=",",
4057                offset=1,
4058            )
4059
4060    @skipif(IS_PYPY, reason="bug in PyPy's PyNumber_AsSsize_t")
4061    def test_fromfile_bad_dup(self, x, tmp_filename):
4062        def dup_str(fd):
4063            return "abc"
4064
4065        def dup_bigint(fd):
4066            return 2**68
4067
4068        old_dup = os.dup
4069        try:
4070            with open(tmp_filename, "wb") as f:
4071                x.tofile(f)
4072                for dup, exc in ((dup_str, TypeError), (dup_bigint, OSError)):
4073                    os.dup = dup
4074                    assert_raises(exc, np.fromfile, f)
4075        finally:
4076            os.dup = old_dup
4077
4078    def _check_from(self, s, value, filename, **kw):
4079        if "sep" not in kw:
4080            y = np.frombuffer(s, **kw)
4081        else:
4082            y = np.fromstring(s, **kw)
4083        assert_array_equal(y, value)
4084
4085        with open(filename, "wb") as f:
4086            f.write(s)
4087        y = np.fromfile(filename, **kw)
4088        assert_array_equal(y, value)
4089
4090    @pytest.fixture(params=["period", "comma"])
4091    def decimal_sep_localization(self, request):
4092        """
4093        Including this fixture in a test will automatically
4094        execute it with both types of decimal separator.
4095
4096        So::
4097
4098            def test_decimal(decimal_sep_localization):
4099                pass
4100
4101        is equivalent to the following two tests::
4102
4103            def test_decimal_period_separator():
4104                pass
4105
4106            def test_decimal_comma_separator():
4107                with CommaDecimalPointLocale():
4108                    pass
4109        """
4110        if request.param == "period":
4111            yield
4112        elif request.param == "comma":
4113            with CommaDecimalPointLocale():
4114                yield
4115        else:
4116            raise AssertionError(request.param)
4117
4118    def test_nan(self, tmp_filename, decimal_sep_localization):
4119        self._check_from(
4120            b"nan +nan -nan NaN nan(foo) +NaN(BAR) -NAN(q_u_u_x_)",
4121            [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
4122            tmp_filename,
4123            sep=" ",
4124        )
4125
4126    def test_inf(self, tmp_filename, decimal_sep_localization):
4127        self._check_from(
4128            b"inf +inf -inf infinity -Infinity iNfInItY -inF",
4129            [np.inf, np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf],
4130            tmp_filename,
4131            sep=" ",
4132        )
4133
4134    def test_numbers(self, tmp_filename, decimal_sep_localization):
4135        self._check_from(
4136            b"1.234 -1.234 .3 .3e55 -123133.1231e+133",
4137            [1.234, -1.234, 0.3, 0.3e55, -123133.1231e133],
4138            tmp_filename,
4139            sep=" ",
4140        )
4141
4142    def test_binary(self, tmp_filename):
4143        self._check_from(
4144            b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@",
4145            np.array([1, 2, 3, 4]),
4146            tmp_filename,
4147            dtype="<f4",
4148        )
4149
4150    @slow  # takes > 1 minute on mechanical hard drive
4151    def test_big_binary(self):
4152        """Test workarounds for 32-bit limit for MSVC fwrite, fseek, and ftell
4153
4154        These normally would hang doing something like this.
4155        See : https://github.com/numpy/numpy/issues/2256
4156        """
4157        if sys.platform != "win32" or "[GCC " in sys.version:
4158            return
4159        try:
4160            # before workarounds, only up to 2**32-1 worked
4161            fourgbplus = 2**32 + 2**16
4162            testbytes = np.arange(8, dtype=np.int8)
4163            n = len(testbytes)
4164            flike = tempfile.NamedTemporaryFile()
4165            f = flike.file
4166            np.tile(testbytes, fourgbplus // testbytes.nbytes).tofile(f)
4167            flike.seek(0)
4168            a = np.fromfile(f, dtype=np.int8)
4169            flike.close()
4170            assert_(len(a) == fourgbplus)
4171            # check only start and end for speed:
4172            assert_((a[:n] == testbytes).all())
4173            assert_((a[-n:] == testbytes).all())
4174        except (MemoryError, ValueError):
4175            pass
4176
4177    def test_string(self, tmp_filename):
4178        self._check_from(b"1,2,3,4", [1.0, 2.0, 3.0, 4.0], tmp_filename, sep=",")
4179
4180    def test_counted_string(self, tmp_filename, decimal_sep_localization):
4181        self._check_from(
4182            b"1,2,3,4", [1.0, 2.0, 3.0, 4.0], tmp_filename, count=4, sep=","
4183        )
4184        self._check_from(b"1,2,3,4", [1.0, 2.0, 3.0], tmp_filename, count=3, sep=",")
4185        self._check_from(
4186            b"1,2,3,4", [1.0, 2.0, 3.0, 4.0], tmp_filename, count=-1, sep=","
4187        )
4188
4189    def test_string_with_ws(self, tmp_filename):
4190        self._check_from(
4191            b"1 2  3     4   ", [1, 2, 3, 4], tmp_filename, dtype=int, sep=" "
4192        )
4193
4194    def test_counted_string_with_ws(self, tmp_filename):
4195        self._check_from(
4196            b"1 2  3     4   ", [1, 2, 3], tmp_filename, count=3, dtype=int, sep=" "
4197        )
4198
4199    def test_ascii(self, tmp_filename, decimal_sep_localization):
4200        self._check_from(b"1 , 2 , 3 , 4", [1.0, 2.0, 3.0, 4.0], tmp_filename, sep=",")
4201        self._check_from(
4202            b"1,2,3,4", [1.0, 2.0, 3.0, 4.0], tmp_filename, dtype=float, sep=","
4203        )
4204
4205    def test_malformed(self, tmp_filename, decimal_sep_localization):
4206        with assert_warns(DeprecationWarning):
4207            self._check_from(b"1.234 1,234", [1.234, 1.0], tmp_filename, sep=" ")
4208
4209    def test_long_sep(self, tmp_filename):
4210        self._check_from(b"1_x_3_x_4_x_5", [1, 3, 4, 5], tmp_filename, sep="_x_")
4211
4212    def test_dtype(self, tmp_filename):
4213        v = np.array([1, 2, 3, 4], dtype=np.int_)
4214        self._check_from(b"1,2,3,4", v, tmp_filename, sep=",", dtype=np.int_)
4215
4216    def test_dtype_bool(self, tmp_filename):
4217        # can't use _check_from because fromstring can't handle True/False
4218        v = np.array([True, False, True, False], dtype=np.bool_)
4219        s = b"1,0,-2.3,0"
4220        with open(tmp_filename, "wb") as f:
4221            f.write(s)
4222        y = np.fromfile(tmp_filename, sep=",", dtype=np.bool_)
4223        assert_(y.dtype == "?")
4224        assert_array_equal(y, v)
4225
4226    def test_tofile_sep(self, tmp_filename, decimal_sep_localization):
4227        x = np.array([1.51, 2, 3.51, 4], dtype=float)
4228        with open(tmp_filename, "w") as f:
4229            x.tofile(f, sep=",")
4230        with open(tmp_filename) as f:
4231            s = f.read()
4232        # assert_equal(s, '1.51,2.0,3.51,4.0')
4233        y = np.array([float(p) for p in s.split(",")])
4234        assert_array_equal(x, y)
4235
4236    def test_tofile_format(self, tmp_filename, decimal_sep_localization):
4237        x = np.array([1.51, 2, 3.51, 4], dtype=float)
4238        with open(tmp_filename, "w") as f:
4239            x.tofile(f, sep=",", format="%.2f")
4240        with open(tmp_filename) as f:
4241            s = f.read()
4242        assert_equal(s, "1.51,2.00,3.51,4.00")
4243
4244    def test_tofile_cleanup(self, tmp_filename):
4245        x = np.zeros((10), dtype=object)
4246        with open(tmp_filename, "wb") as f:
4247            assert_raises(OSError, lambda: x.tofile(f, sep=""))
4248        # Dup-ed file handle should be closed or remove will fail on Windows OS
4249        os.remove(tmp_filename)
4250
4251        # Also make sure that we close the Python handle
4252        assert_raises(OSError, lambda: x.tofile(tmp_filename))
4253        os.remove(tmp_filename)
4254
4255    def test_fromfile_subarray_binary(self, tmp_filename):
4256        # Test subarray dtypes which are absorbed into the shape
4257        x = np.arange(24, dtype="i4").reshape(2, 3, 4)
4258        x.tofile(tmp_filename)
4259        res = np.fromfile(tmp_filename, dtype="(3,4)i4")
4260        assert_array_equal(x, res)
4261
4262        x_str = x.tobytes()
4263        with assert_warns(DeprecationWarning):
4264            # binary fromstring is deprecated
4265            res = np.fromstring(x_str, dtype="(3,4)i4")
4266            assert_array_equal(x, res)
4267
4268    def test_parsing_subarray_unsupported(self, tmp_filename):
4269        # We currently do not support parsing subarray dtypes
4270        data = "12,42,13," * 50
4271        with pytest.raises(ValueError):
4272            expected = np.fromstring(data, dtype="(3,)i", sep=",")
4273
4274        with open(tmp_filename, "w") as f:
4275            f.write(data)
4276
4277        with pytest.raises(ValueError):
4278            np.fromfile(tmp_filename, dtype="(3,)i", sep=",")
4279
4280    def test_read_shorter_than_count_subarray(self, tmp_filename):
4281        # Test that requesting more values does not cause any problems
4282        # in conjunction with subarray dimensions being absorbed into the
4283        # array dimension.
4284        expected = np.arange(511 * 10, dtype="i").reshape(-1, 10)
4285
4286        binary = expected.tobytes()
4287        with pytest.raises(ValueError):
4288            with pytest.warns(DeprecationWarning):
4289                np.fromstring(binary, dtype="(10,)i", count=10000)
4290
4291        expected.tofile(tmp_filename)
4292        res = np.fromfile(tmp_filename, dtype="(10,)i", count=10000)
4293        assert_array_equal(res, expected)
4294
4295
@xpassIfTorchDynamo  # (reason="TODO")
@instantiate_parametrized_tests
class TestFromBuffer(TestCase):
    """Tests for ``np.frombuffer``."""

    @parametrize(
        "byteorder", [subtest("little", name="little"), subtest("big", name="big")]
    )
    @parametrize("dtype", [float, int, complex])
    def test_basic(self, byteorder, dtype):
        # Round-trip random data through raw bytes for both byte orders.
        dt = np.dtype(dtype).newbyteorder(byteorder)
        x = (np.random.random((4, 7)) * 5).astype(dt)
        buf = x.tobytes()
        assert_array_equal(np.frombuffer(buf, dtype=dt), x.flat)

    #    @xpassIfTorchDynamo
    @parametrize(
        "obj", [np.arange(10), subtest("12345678", decorators=[xfailIfTorchDynamo])]
    )
    def test_array_base(self, obj):
        # Objects (including NumPy arrays), which do not use the
        # `release_buffer` slot should be directly used as a base object.
        # See also gh-21612
        if isinstance(obj, str):
            # @parametrize breaks with bytes objects, so the bytes case is
            # passed as str and converted here. (The keyword is `encoding`;
            # a misspelled keyword raises TypeError before frombuffer runs.)
            obj = bytes(obj, encoding="latin-1")
        new = np.frombuffer(obj)
        assert new.base is obj

    def test_empty(self):
        assert_array_equal(np.frombuffer(b""), np.array([]))

    @skip("fails on CI, we are unlikely to implement this")
    @skipif(
        IS_PYPY,
        reason="PyPy's memoryview currently does not track exports. See: "
        "https://foss.heptapod.net/pypy/pypy/-/issues/3724",
    )
    def test_mmap_close(self):
        # The old buffer protocol was not safe for some things that the new
        # one is.  But `frombuffer` always used the old one for a long time.
        # Checks that it is safe with the new one (using memoryviews)
        with tempfile.TemporaryFile(mode="wb") as tmp:
            tmp.write(b"asdf")
            tmp.flush()
            mm = mmap.mmap(tmp.fileno(), 0)
            arr = np.frombuffer(mm, dtype=np.uint8)
            with pytest.raises(BufferError):
                mm.close()  # cannot close while array uses the buffer
            del arr
            mm.close()
4345
4346
@skip  # (reason="TODO")   # FIXME: skip -> xfail (a0.shape = (4, 5) raises)
class TestFlat(TestCase):
    """Tests for the ``.flat`` iterator on read-only and writeable arrays."""

    def setUp(self):
        base = np.arange(20.0)
        readonly = base.reshape(4, 5)
        base.shape = (4, 5)
        readonly.flags.writeable = False
        # a / b: read-only 4x5 array and a strided view of it
        self.a = readonly
        self.b = readonly[::2, ::2]
        # a0 / b0: the writeable base array and a strided view of it
        self.a0 = base
        self.b0 = base[::2, ::2]

    def test_contiguous(self):
        # Assigning through .flat of a read-only array raises ValueError and
        # leaves the element unchanged.
        raised = False
        try:
            self.a.flat[12] = 100.0
        except ValueError:
            raised = True
        assert_(raised)
        assert_(self.a.flat[12] == 12.0)

    def test_discontiguous(self):
        # Same as above, for a non-contiguous read-only view.
        raised = False
        try:
            self.b.flat[4] = 100.0
        except ValueError:
            raised = True
        assert_(raised)
        assert_(self.b.flat[4] == 12.0)

    def test___array__(self):
        arrays = [
            source.flat.__array__() for source in (self.a, self.b, self.a0, self.b0)
        ]
        # Only the flat view of the contiguous writeable array is writeable.
        for arr, writeable in zip(arrays, [False, False, True, False]):
            assert_(arr.flags.writeable is writeable)
        for arr in arrays:
            assert_(arr.flags.writebackifcopy is False)

    @skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
    def test_refcount(self):
        # includes regression test for reference count error gh-13165:
        # repeated .flat indexing must not noticeably change the refcounts
        # of the index objects or of the intp dtype.
        indices = [np.intp(0), np.array([True] * self.a.size), np.array([0]), None]
        intp_dtype = np.dtype(np.intp)
        dtype_rc = sys.getrefcount(intp_dtype)
        for index in indices:
            index_rc = sys.getrefcount(index)
            for _ in range(100):
                try:
                    self.a.flat[index]
                except IndexError:
                    pass
            assert_(abs(sys.getrefcount(index) - index_rc) < 50)
            assert_(abs(sys.getrefcount(intp_dtype) - dtype_rc) < 50)

    def test_index_getset(self):
        flat_iter = np.arange(10).reshape(2, 1, 5).flat
        # .index cannot be assigned
        with pytest.raises(AttributeError):
            flat_iter.index = 10

        for _ in flat_iter:
            pass
        # After exhausting the iterator .index equals the element count
        # (see also gh-19153). If the type was incorrect, this would show up
        # on big-endian machines.
        assert flat_iter.index == flat_iter.base.size
4418
4419
class TestResize(TestCase):
    """Tests for in-place ``ndarray.resize``."""

    @_no_tracing
    def test_basic(self):
        x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        if IS_PYPY:
            x.resize((5, 5), refcheck=False)
        else:
            x.resize((5, 5))
        # The first 9 elements (C order) keep the original data ...
        assert_array_equal(
            x.ravel()[:9], np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).ravel()
        )
        # ... and the 16 elements added by growing are zeros. The check must
        # slice the flattened array: ``x[9:]`` selects no rows of a 5x5 array
        # and would make this assertion vacuous.
        assert_array_equal(x.ravel()[9:], 0)

    @skip(reason="how to find if someone is refencing an array")
    def test_check_reference(self):
        x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        y = x
        assert_raises(ValueError, x.resize, (5, 1))
        del y  # avoid pyflakes unused variable warning.

    @_no_tracing
    def test_int_shape(self):
        # A bare int is treated as a 1-D target shape.
        x = np.eye(3)
        if IS_PYPY:
            x.resize(3, refcheck=False)
        else:
            x.resize(3)
        assert_array_equal(x, np.eye(3)[0, :])

    def test_none_shape(self):
        # resize(None) and resize() leave the array unchanged.
        x = np.eye(3)
        x.resize(None)
        assert_array_equal(x, np.eye(3))
        x.resize()
        assert_array_equal(x, np.eye(3))

    def test_0d_shape(self):
        # do it multiple times to test it does not break alloc cache gh-9216
        for _ in range(10):
            x = np.empty((1,))
            x.resize(())
            assert_equal(x.shape, ())
            assert_equal(x.size, 1)
            x = np.empty(())
            x.resize((1,))
            assert_equal(x.shape, (1,))
            assert_equal(x.size, 1)

    def test_invalid_arguments(self):
        assert_raises(TypeError, np.eye(3).resize, "hi")
        assert_raises(ValueError, np.eye(3).resize, -1)
        assert_raises(TypeError, np.eye(3).resize, order=1)
        assert_raises((NotImplementedError, TypeError), np.eye(3).resize, refcheck="hi")

    @_no_tracing
    def test_freeform_shape(self):
        # The new shape may be given as separate positional ints.
        x = np.eye(3)
        if IS_PYPY:
            x.resize(3, 2, 1, refcheck=False)
        else:
            x.resize(3, 2, 1)
        assert_(x.shape == (3, 2, 1))

    @_no_tracing
    def test_zeros_appended(self):
        x = np.eye(3)
        if IS_PYPY:
            x.resize(2, 3, 3, refcheck=False)
        else:
            x.resize(2, 3, 3)
        assert_array_equal(x[0], np.eye(3))
        assert_array_equal(x[1], np.zeros((3, 3)))

    def test_empty_view(self):
        # check that sizes containing a zero don't trigger a reallocate for
        # already empty arrays
        x = np.zeros((10, 0), int)
        x_view = x[...]
        x_view.resize((0, 10))
        x_view.resize((0, 100))

    @skip(reason="ignore weakrefs for ndarray.resize")
    def test_check_weakref(self):
        x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        xref = weakref.ref(x)
        assert_raises(ValueError, x.resize, (5, 1))
        del xref  # avoid pyflakes unused variable warning.
4507
4508
4509def _mean(a, **args):
4510    return a.mean(**args)
4511
4512
4513def _var(a, **args):
4514    return a.var(**args)
4515
4516
4517def _std(a, **args):
4518    return a.std(**args)
4519
4520
@instantiate_parametrized_tests
class TestStats(TestCase):
    """Tests for mean/var/std, mostly through the ``_mean``/``_var``/``_std`` helpers."""

    funcs = [_mean, _var, _std]

    def setUp(self):
        # Seeded random real matrix and a complex matrix derived from it.
        np.random.seed(3)
        self.rmat = np.random.random((4, 5))
        self.cmat = self.rmat + 1j * self.rmat

    def test_python_type(self):
        for x in (np.float16(1.0), 1, 1.0, 1 + 0j):
            assert_equal(np.mean([x]), 1.0)
            assert_equal(np.std([x]), 0.0)
            assert_equal(np.var([x]), 0.0)

    def test_keepdims(self):
        mat = np.eye(3)
        for f in self.funcs:
            for axis in [0, 1]:
                res = f(mat, axis=axis, keepdims=True)
                assert_(res.ndim == mat.ndim)
                assert_(res.shape[axis] == 1)
            for axis in [None]:
                res = f(mat, axis=axis, keepdims=True)
                assert_(res.shape == (1, 1))

    def test_out(self):
        mat = np.eye(3)
        for f in self.funcs:
            out = np.zeros(3)
            tgt = f(mat, axis=1)
            res = f(mat, axis=1, out=out)
            assert_almost_equal(res, out)
            assert_almost_equal(res, tgt)
        # NOTE: `f` below is whatever the loop left bound (the last entry of
        # self.funcs), so the wrong-shape `out` errors are only checked for it.
        out = np.empty(2)
        assert_raises(ValueError, f, mat, axis=1, out=out)
        out = np.empty((2, 2))
        assert_raises(ValueError, f, mat, axis=1, out=out)

    def test_dtype_from_input(self):
        icodes = np.typecodes["AllInteger"]
        fcodes = np.typecodes["AllFloat"]

        # integer types
        for f in self.funcs:
            for c in icodes:
                mat = np.eye(3, dtype=c)
                tgt = np.float64
                res = f(mat, axis=1).dtype.type
                assert_(res is tgt)
                # scalar case
                res = f(mat, axis=None).dtype.type
                assert_(res is tgt)

        # mean for float types
        for f in [_mean]:
            for c in fcodes:
                mat = np.eye(3, dtype=c)
                tgt = mat.dtype.type
                res = f(mat, axis=1).dtype.type
                assert_(res is tgt)
                # scalar case
                res = f(mat, axis=None).dtype.type
                assert_(res is tgt)

        # var, std for float types
        for f in [_var, _std]:
            for c in fcodes:
                mat = np.eye(3, dtype=c)
                # deal with complex types
                tgt = mat.real.dtype.type
                res = f(mat, axis=1).dtype.type
                assert_(res is tgt)
                # scalar case
                res = f(mat, axis=None).dtype.type
                assert_(res is tgt)

    def test_dtype_from_dtype(self):
        mat = np.eye(3)

        # stats for integer types
        # FIXME:
        # this needs definition as there are lots places along the line
        # where type casting may take place.

        # for f in self.funcs:
        #    for c in np.typecodes['AllInteger']:
        #        tgt = np.dtype(c).type
        #        res = f(mat, axis=1, dtype=c).dtype.type
        #        assert_(res is tgt)
        #        # scalar case
        #        res = f(mat, axis=None, dtype=c).dtype.type
        #        assert_(res is tgt)

        # stats for float types
        for f in self.funcs:
            for c in np.typecodes["AllFloat"]:
                tgt = np.dtype(c).type
                res = f(mat, axis=1, dtype=c).dtype.type
                assert_(res is tgt)
                # scalar case
                res = f(mat, axis=None, dtype=c).dtype.type
                assert_(res is tgt)

    def test_ddof(self):
        # var divides the sum of squared deviations by (N - ddof). Rescaling
        # by (N - ddof) (resp. its sqrt for std) must therefore give the same
        # value for every ddof as ddof=0 rescaled by N.
        for f in [_var]:
            for ddof in range(3):
                dim = self.rmat.shape[1]
                tgt = f(self.rmat, axis=1) * dim
                res = f(self.rmat, axis=1, ddof=ddof) * (dim - ddof)
                assert_almost_equal(res, tgt)
        for f in [_std]:
            for ddof in range(3):
                dim = self.rmat.shape[1]
                tgt = f(self.rmat, axis=1) * np.sqrt(dim)
                res = f(self.rmat, axis=1, ddof=ddof) * np.sqrt(dim - ddof)
                assert_almost_equal(res, tgt)

    def test_ddof_too_big(self):
        dim = self.rmat.shape[1]
        for f in [_var, _std]:
            for ddof in range(dim, dim + 2):
                #         with warnings.catch_warnings(record=True) as w:
                #             warnings.simplefilter('always')
                res = f(self.rmat, axis=1, ddof=ddof)
                assert_(not (res < 0).any())
        #            assert_(len(w) > 0)
        #            assert_(issubclass(w[0].category, RuntimeWarning))

    def test_empty(self):
        A = np.zeros((0, 3))
        for f in self.funcs:
            for axis in [0, None]:
                #      with warnings.catch_warnings(record=True) as w:
                #          warnings.simplefilter('always')
                assert_(np.isnan(f(A, axis=axis)).all())
            #          assert_(len(w) > 0)
            #          assert_(issubclass(w[0].category, RuntimeWarning))
            for axis in [1]:
                #      with warnings.catch_warnings(record=True) as w:
                #          warnings.simplefilter('always')
                assert_equal(f(A, axis=axis), np.zeros([]))

    def test_mean_values(self):
        for mat in [self.rmat, self.cmat]:
            for axis in [0, 1]:
                tgt = mat.sum(axis=axis)
                res = _mean(mat, axis=axis) * mat.shape[axis]
                assert_almost_equal(res, tgt)
            for axis in [None]:
                tgt = mat.sum(axis=axis)
                res = _mean(mat, axis=axis) * np.prod(mat.shape)
                assert_almost_equal(res, tgt)

    def test_mean_float16(self):
        # This fail if the sum inside mean is done in float16 instead
        # of float32.
        assert_(_mean(np.ones(100000, dtype="float16")) == 1)

    def test_mean_axis_error(self):
        # Ensure that AxisError is raised instead of IndexError when axis is
        # out of bounds, see gh-15817.
        with assert_raises(np.AxisError):
            np.arange(10).mean(axis=2)

    @xpassIfTorchDynamo  # (reason="implement mean(..., where=...)")
    def test_mean_where(self):
        a = np.arange(16).reshape((4, 4))
        wh_full = np.array(
            [
                [False, True, False, True],
                [True, False, True, False],
                [True, True, False, False],
                [False, False, True, True],
            ]
        )
        wh_partial = np.array([[False], [True], [True], [False]])
        _cases = [
            (1, True, [1.5, 5.5, 9.5, 13.5]),
            (0, wh_full, [6.0, 5.0, 10.0, 9.0]),
            (1, wh_full, [2.0, 5.0, 8.5, 14.5]),
            (0, wh_partial, [6.0, 7.0, 8.0, 9.0]),
        ]
        for _ax, _wh, _res in _cases:
            assert_allclose(a.mean(axis=_ax, where=_wh), np.array(_res))
            assert_allclose(np.mean(a, axis=_ax, where=_wh), np.array(_res))

        a3d = np.arange(16).reshape((2, 2, 4))
        _wh_partial = np.array([False, True, True, False])
        _res = [[1.5, 5.5], [9.5, 13.5]]
        assert_allclose(a3d.mean(axis=2, where=_wh_partial), np.array(_res))
        assert_allclose(np.mean(a3d, axis=2, where=_wh_partial), np.array(_res))

        with pytest.warns(RuntimeWarning) as w:
            assert_allclose(
                a.mean(axis=1, where=wh_partial), np.array([np.nan, 5.5, 9.5, np.nan])
            )
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(a.mean(where=False), np.nan)
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(np.mean(a, where=False), np.nan)

    def test_var_values(self):
        # var == E[|x|^2] - |E[x]|^2, checked for real and complex input.
        for mat in [self.rmat, self.cmat]:
            for axis in [0, 1, None]:
                msqr = _mean(mat * mat.conj(), axis=axis)
                mean = _mean(mat, axis=axis)
                tgt = msqr - mean * mean.conjugate()
                res = _var(mat, axis=axis)
                assert_almost_equal(res, tgt)

    @parametrize(
        "complex_dtype, ndec",
        (
            ("complex64", 6),
            ("complex128", 7),
        ),
    )
    def test_var_complex_values(self, complex_dtype, ndec):
        # Test fast-paths for every builtin complex type
        for axis in [0, 1, None]:
            mat = self.cmat.copy().astype(complex_dtype)
            msqr = _mean(mat * mat.conj(), axis=axis)
            mean = _mean(mat, axis=axis)
            tgt = msqr - mean * mean.conjugate()
            res = _var(mat, axis=axis)
            assert_almost_equal(res, tgt, decimal=ndec)

    def test_var_dimensions(self):
        # _var paths for complex number introduce additions on views that
        # increase dimensions. Ensure this generalizes to higher dims
        mat = np.stack([self.cmat] * 3)
        for axis in [0, 1, 2, -1, None]:
            msqr = _mean(mat * mat.conj(), axis=axis)
            mean = _mean(mat, axis=axis)
            tgt = msqr - mean * mean.conjugate()
            res = _var(mat, axis=axis)
            assert_almost_equal(res, tgt)

    @skip(reason="endianness")
    def test_var_complex_byteorder(self):
        # Test that var fast-path does not cause failures for complex arrays
        # with non-native byteorder
        cmat = self.cmat.copy().astype("complex128")
        cmat_swapped = cmat.astype(cmat.dtype.newbyteorder())
        assert_almost_equal(cmat.var(), cmat_swapped.var())

    def test_var_axis_error(self):
        # Ensure that AxisError is raised instead of IndexError when axis is
        # out of bounds, see gh-15817.
        with assert_raises(np.AxisError):
            np.arange(10).var(axis=2)

    @xpassIfTorchDynamo  # (reason="implement var(..., where=...)")
    def test_var_where(self):
        a = np.arange(25).reshape((5, 5))
        wh_full = np.array(
            [
                [False, True, False, True, True],
                [True, False, True, True, False],
                [True, True, False, False, True],
                [False, True, True, False, True],
                [True, False, True, True, False],
            ]
        )
        wh_partial = np.array([[False], [True], [True], [False], [True]])
        _cases = [
            (0, True, [50.0, 50.0, 50.0, 50.0, 50.0]),
            (1, True, [2.0, 2.0, 2.0, 2.0, 2.0]),
        ]
        for _ax, _wh, _res in _cases:
            assert_allclose(a.var(axis=_ax, where=_wh), np.array(_res))
            assert_allclose(np.var(a, axis=_ax, where=_wh), np.array(_res))

        a3d = np.arange(16).reshape((2, 2, 4))
        _wh_partial = np.array([False, True, True, False])
        _res = [[0.25, 0.25], [0.25, 0.25]]
        assert_allclose(a3d.var(axis=2, where=_wh_partial), np.array(_res))
        assert_allclose(np.var(a3d, axis=2, where=_wh_partial), np.array(_res))

        assert_allclose(
            np.var(a, axis=1, where=wh_full), np.var(a[wh_full].reshape((5, 3)), axis=1)
        )
        assert_allclose(
            np.var(a, axis=0, where=wh_partial), np.var(a[wh_partial[:, 0]], axis=0)
        )
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(a.var(where=False), np.nan)
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(np.var(a, where=False), np.nan)

    def test_std_values(self):
        for mat in [self.rmat, self.cmat]:
            for axis in [0, 1, None]:
                tgt = np.sqrt(_var(mat, axis=axis))
                res = _std(mat, axis=axis)
                assert_almost_equal(res, tgt)

    @xpassIfTorchDynamo  # (reason="implement std(..., where=...)")
    def test_std_where(self):
        a = np.arange(25).reshape((5, 5))[::-1]
        whf = np.array(
            [
                [False, True, False, True, True],
                [True, False, True, False, True],
                [True, True, False, True, False],
                [True, False, True, True, False],
                [False, True, False, True, True],
            ]
        )
        whp = np.array([[False], [False], [True], [True], [False]])
        _cases = [
            (0, True, 7.07106781 * np.ones(5)),
            (1, True, 1.41421356 * np.ones(5)),
            (0, whf, np.array([4.0824829, 8.16496581, 5.0, 7.39509973, 8.49836586])),
            (0, whp, 2.5 * np.ones(5)),
        ]
        for _ax, _wh, _res in _cases:
            assert_allclose(a.std(axis=_ax, where=_wh), _res)
            assert_allclose(np.std(a, axis=_ax, where=_wh), _res)

        a3d = np.arange(16).reshape((2, 2, 4))
        _wh_partial = np.array([False, True, True, False])
        _res = [[0.5, 0.5], [0.5, 0.5]]
        assert_allclose(a3d.std(axis=2, where=_wh_partial), np.array(_res))
        assert_allclose(np.std(a3d, axis=2, where=_wh_partial), np.array(_res))

        assert_allclose(
            a.std(axis=1, where=whf), np.std(a[whf].reshape((5, 3)), axis=1)
        )
        assert_allclose(
            np.std(a, axis=1, where=whf), (a[whf].reshape((5, 3))).std(axis=1)
        )
        assert_allclose(a.std(axis=0, where=whp), np.std(a[whp[:, 0]], axis=0))
        assert_allclose(np.std(a, axis=0, where=whp), (a[whp[:, 0]]).std(axis=0))
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(a.std(where=False), np.nan)
        with pytest.warns(RuntimeWarning) as w:
            assert_equal(np.std(a, where=False), np.nan)
4860
4861
class TestVdot(TestCase):
    """Tests for ``np.vdot``: scalar results, dtype coverage, memory layouts."""

    def test_basic(self):
        dt_numeric = np.typecodes["AllFloat"] + np.typecodes["AllInteger"]
        dt_complex = np.typecodes["Complex"]

        # test real: vdot of the 3x3 identity with itself sums the diagonal, 3
        a = np.eye(3)
        for dt in dt_numeric:
            b = a.astype(dt)
            res = np.vdot(b, b)
            assert_(np.isscalar(res))
            assert_equal(res, 3)

        # test complex: the expected +3 (not -3) relies on vdot conjugating
        # its first argument, so each diagonal term is conj(1j) * 1j = 1.
        a = np.eye(3) * 1j
        for dt in dt_complex:
            b = a.astype(dt)
            res = np.vdot(b, b)
            assert_(np.isscalar(res))
            assert_equal(res, 3)

        # test boolean
        b = np.eye(3, dtype=bool)
        res = np.vdot(b, b)
        assert_(np.isscalar(res))
        assert_equal(res, True)

    @xpassIfTorchDynamo  # (reason="implement order='F'")
    def test_vdot_array_order(self):
        a = np.array([[1, 2], [3, 4]], order="C")
        b = np.array([[1, 2], [3, 4]], order="F")
        res = np.vdot(a, a)

        # integer arrays are exact
        assert_equal(np.vdot(a, b), res)
        assert_equal(np.vdot(b, a), res)
        assert_equal(np.vdot(b, b), res)

    def test_vdot_uncontiguous(self):
        for size in [2, 1000]:
            # Different sizes match different branches in vdot.
            a = np.zeros((size, 2, 2))
            b = np.zeros((size, 2, 2))
            a[:, 0, 0] = np.arange(size)
            b[:, 0, 0] = np.arange(size) + 1
            # Make a and b uncontiguous:
            a = a[..., 0]
            b = b[..., 0]

            # Reference result computed on contiguous, flattened copies.
            expected = np.vdot(a.flatten(), b.flatten())
            assert_equal(np.vdot(a, b), expected)
            assert_equal(np.vdot(a, b.copy()), expected)
            assert_equal(np.vdot(a.copy(), b), expected)

    @xpassIfTorchDynamo  # (reason="implement order='F'")
    def test_vdot_uncontiguous_2(self):
        # test order='F' separately
        for size in [2, 1000]:
            # Different sizes match different branches in vdot.
            a = np.zeros((size, 2, 2))
            b = np.zeros((size, 2, 2))
            a[:, 0, 0] = np.arange(size)
            b[:, 0, 0] = np.arange(size) + 1
            # Make a and b uncontiguous:
            a = a[..., 0]
            b = b[..., 0]

            # Reference result computed on contiguous, flattened copies.
            expected = np.vdot(a.flatten(), b.flatten())
            assert_equal(np.vdot(a.copy("F"), b), expected)
            assert_equal(np.vdot(a, b.copy("F")), expected)
4930
4931
@instantiate_parametrized_tests
class TestDot(TestCase):
    """Tests for ``np.dot`` over scalar, vector and matrix operand shapes."""

    def setUp(self):
        # Chain to the base-class setUp so whatever per-test setup it does
        # still runs before this class installs its fixtures.
        super().setUp()
        # Kept for tests that still draw random numbers without seeding
        # themselves (e.g. test_accelerate_framework_sgemv_fix).
        np.random.seed(128)

        # Numpy and pytorch random streams differ, so inline the
        # values from numpy 1.24.1
        # self.A = np.random.rand(4, 2)
        self.A = np.array(
            [
                [0.86663704, 0.26314485],
                [0.13140848, 0.04159344],
                [0.23892433, 0.6454746],
                [0.79059935, 0.60144244],
            ]
        )

        # self.b1 = np.random.rand(2, 1)
        self.b1 = np.array([[0.33429937], [0.11942846]])

        # self.b2 = np.random.rand(2)
        self.b2 = np.array([0.30913305, 0.10972379])

        # self.b3 = np.random.rand(1, 2)
        self.b3 = np.array([[0.60211331, 0.25128496]])

        # self.b4 = np.random.rand(4)
        self.b4 = np.array([0.29968129, 0.517116, 0.71520252, 0.9314494])

        # Number of decimals passed to assert_almost_equal in these tests.
        self.N = 7

    def test_dotmatmat(self):
        A = self.A
        res = np.dot(A.transpose(), A)
        tgt = np.array([[1.45046013, 0.86323640], [0.86323640, 0.84934569]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotmatvec(self):
        A, b1 = self.A, self.b1
        res = np.dot(A, b1)
        tgt = np.array([[0.32114320], [0.04889721], [0.15696029], [0.33612621]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotmatvec2(self):
        A, b2 = self.A, self.b2
        res = np.dot(A, b2)
        tgt = np.array([0.29677940, 0.04518649, 0.14468333, 0.31039293])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecmat(self):
        A, b4 = self.A, self.b4
        res = np.dot(b4, A)
        tgt = np.array([1.23495091, 1.12222648])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecmat2(self):
        b3, A = self.b3, self.A
        res = np.dot(b3, A.transpose())
        tgt = np.array([[0.58793804, 0.08957460, 0.30605758, 0.62716383]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecmat3(self):
        A, b4 = self.A, self.b4
        res = np.dot(A.transpose(), b4)
        tgt = np.array([1.23495091, 1.12222648])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecvecouter(self):
        b1, b3 = self.b1, self.b3
        res = np.dot(b1, b3)
        tgt = np.array([[0.20128610, 0.08400440], [0.07190947, 0.03001058]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecvecinner(self):
        b1, b3 = self.b1, self.b3
        res = np.dot(b3, b1)
        tgt = np.array([[0.23129668]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotcolumnvect1(self):
        b1 = np.ones((3, 1))
        b2 = [5.3]
        res = np.dot(b1, b2)
        tgt = np.array([5.3, 5.3, 5.3])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotcolumnvect2(self):
        b1 = np.ones((3, 1)).transpose()
        b2 = [6.2]
        res = np.dot(b2, b1)
        tgt = np.array([6.2, 6.2, 6.2])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecscalar(self):
        # Numpy guarantees the random stream, and we don't. So inline the
        # values numpy 1.24.1 produces after np.random.seed(100); nothing
        # here draws from the random stream, so no seeding is needed.
        # b1 = np.random.rand(1, 1)
        b1 = np.array([[0.54340494]])

        # b2 = np.random.rand(1, 4)
        b2 = np.array([[0.27836939, 0.42451759, 0.84477613, 0.00471886]])

        res = np.dot(b1, b2)
        tgt = np.array([[0.15126730, 0.23068496, 0.45905553, 0.00256425]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_dotvecscalar2(self):
        # Values inlined from numpy 1.24.1 after np.random.seed(100).
        # b1 = np.random.rand(4, 1)
        b1 = np.array([[0.54340494], [0.27836939], [0.42451759], [0.84477613]])

        # b2 = np.random.rand(1, 1)
        b2 = np.array([[0.00471886]])

        res = np.dot(b1, b2)
        tgt = np.array([[0.00256425], [0.00131359], [0.00200324], [0.00398638]])
        assert_almost_equal(res, tgt, decimal=self.N)

    def test_all(self):
        # dout[k] is the expected result shape for the k-th pair produced by
        # itertools.product(dims, dims).
        dims = [(), (1,), (1, 1)]
        dout = [(), (1,), (1, 1), (1,), (), (1,), (1, 1), (1,), (1, 1)]
        for dim, (dim1, dim2) in zip(dout, itertools.product(dims, dims)):
            b1 = np.zeros(dim1)
            b2 = np.zeros(dim2)
            res = np.dot(b1, b2)
            tgt = np.zeros(dim)
            assert_(res.shape == tgt.shape)
            assert_almost_equal(res, tgt, decimal=self.N)

    @skip(reason="numpy internals")
    def test_dot_2args(self):
        from numpy.core.multiarray import dot

        a = np.array([[1, 2], [3, 4]], dtype=float)
        b = np.array([[1, 0], [1, 1]], dtype=float)
        c = np.array([[3, 2], [7, 4]], dtype=float)

        d = dot(a, b)
        assert_allclose(c, d)

    @skip(reason="numpy internals")
    def test_dot_3args(self):
        from numpy.core.multiarray import dot

        np.random.seed(22)
        f = np.random.random_sample((1024, 16))
        v = np.random.random_sample((16, 32))

        r = np.empty((1024, 32))
        # Repeated calls writing into the same `out` must not leak references
        # to it (checked via getrefcount below).
        for _ in range(12):
            dot(f, v, r)
        if HAS_REFCOUNT:
            assert_equal(sys.getrefcount(r), 2)
        r2 = dot(f, v, out=None)
        assert_array_equal(r2, r)
        assert_(r is dot(f, v, out=r))

        v = v[:, 0].copy()  # v.shape == (16,)
        r = r[:, 0].copy()  # r.shape == (1024,)
        r2 = dot(f, v)
        assert_(r is dot(f, v, r))
        assert_array_equal(r2, r)

    @skip(reason="numpy internals")
    def test_dot_3args_errors(self):
        from numpy.core.multiarray import dot

        np.random.seed(22)
        f = np.random.random_sample((1024, 16))
        v = np.random.random_sample((16, 32))

        r = np.empty((1024, 31))
        assert_raises(ValueError, dot, f, v, r)

        r = np.empty((1024,))
        assert_raises(ValueError, dot, f, v, r)

        r = np.empty((32,))
        assert_raises(ValueError, dot, f, v, r)

        r = np.empty((32, 1024))
        assert_raises(ValueError, dot, f, v, r)
        assert_raises(ValueError, dot, f, v, r.T)

        r = np.empty((1024, 64))
        assert_raises(ValueError, dot, f, v, r[:, ::2])
        assert_raises(ValueError, dot, f, v, r[:, :32])

        r = np.empty((1024, 32), dtype=np.float32)
        assert_raises(ValueError, dot, f, v, r)

        r = np.empty((1024, 32), dtype=int)
        assert_raises(ValueError, dot, f, v, r)

    @xpassIfTorchDynamo  # (reason="TODO order='F'")
    def test_dot_array_order(self):
        a = np.array([[1, 2], [3, 4]], order="C")
        b = np.array([[1, 2], [3, 4]], order="F")
        res = np.dot(a, a)

        # integer arrays are exact
        assert_equal(np.dot(a, b), res)
        assert_equal(np.dot(b, a), res)
        assert_equal(np.dot(b, b), res)

    @skip(reason="TODO: nbytes, view, __array_interface__")
    def test_accelerate_framework_sgemv_fix(self):
        def aligned_array(shape, align, dtype, order="C"):
            # Carve an array of `shape` out of a byte buffer, starting at the
            # first offset whose address is a multiple of `align`.
            d = dtype(0)
            N = np.prod(shape)
            tmp = np.zeros(N * d.nbytes + align, dtype=np.uint8)
            address = tmp.__array_interface__["data"][0]
            for offset in range(align):
                if (address + offset) % align == 0:
                    break
            tmp = tmp[offset : offset + N * d.nbytes].view(dtype=dtype)
            return tmp.reshape(shape, order=order)

        def as_aligned(arr, align, dtype, order="C"):
            # Copy `arr` into a freshly allocated aligned array.
            aligned = aligned_array(arr.shape, align, dtype, order)
            aligned[:] = arr[:]
            return aligned

        def assert_dot_close(A, X, desired):
            assert_allclose(np.dot(A, X), desired, rtol=1e-5, atol=1e-7)

        m = aligned_array(100, 15, np.float32)
        s = aligned_array((100, 100), 15, np.float32)
        np.dot(s, m)  # this will always segfault if the bug is present

        testdata = itertools.product((15, 32), (10000,), (200, 89), ("C", "F"))
        for align, m, n, a_order in testdata:
            # Calculation in double precision
            A_d = np.random.rand(m, n)
            X_d = np.random.rand(n)
            desired = np.dot(A_d, X_d)
            # Calculation with aligned single precision
            A_f = as_aligned(A_d, align, np.float32, order=a_order)
            X_f = as_aligned(X_d, align, np.float32)
            assert_dot_close(A_f, X_f, desired)
            # Strided A rows
            A_d_2 = A_d[::2]
            desired = np.dot(A_d_2, X_d)
            A_f_2 = A_f[::2]
            assert_dot_close(A_f_2, X_f, desired)
            # Strided A columns, strided X vector
            A_d_22 = A_d_2[:, ::2]
            X_d_2 = X_d[::2]
            desired = np.dot(A_d_22, X_d_2)
            A_f_22 = A_f_2[:, ::2]
            X_f_2 = X_f[::2]
            assert_dot_close(A_f_22, X_f_2, desired)
            # Check the strides are as expected
            if a_order == "F":
                assert_equal(A_f_22.strides, (8, 8 * m))
            else:
                assert_equal(A_f_22.strides, (8 * n, 8))
            assert_equal(X_f_2.strides, (8,))
            # Strides in A rows + cols only
            X_f_2c = as_aligned(X_f_2, align, np.float32)
            assert_dot_close(A_f_22, X_f_2c, desired)
            # Strides just in A cols
            A_d_12 = A_d[:, ::2]
            desired = np.dot(A_d_12, X_d_2)
            A_f_12 = A_f[:, ::2]
            assert_dot_close(A_f_12, X_f_2c, desired)
            # Strides in A cols and X
            assert_dot_close(A_f_12, X_f_2, desired)

    @slow
    @parametrize("dtype", [np.float64, np.complex128])
    @requires_memory(free_bytes=18e9)  # complex case needs 18GiB+
    def test_huge_vectordot(self, dtype):
        # Large vector multiplications are chunked with 32bit BLAS
        # Test that the chunking does the right thing, see also gh-22262
        data = np.ones(2**30 + 100, dtype=dtype)
        res = np.dot(data, data)
        assert res == 2**30 + 100
5211
5212
class MatmulCommon:
    """Common tests for '@' operator and numpy.matmul.

    Subclasses provide ``self.matmul``, the callable under test.
    """

    # Should work with these types. Will want to add
    # "O" at some point
    types = "?bhilBefdFD"

    def test_exceptions(self):
        # Operand shape pairs that matmul must reject.
        dims = [
            ((1,), (2,)),  # mismatched vector vector
            ((2, 1), (2,)),  # mismatched matrix vector
            ((2,), (1, 2)),  # mismatched vector matrix
            ((1, 2), (3, 1)),  # mismatched matrix matrix
            ((1,), ()),  # vector scalar
            ((), (1,)),  # scalar vector
            ((1, 1), ()),  # matrix scalar
            ((), (1, 1)),  # scalar matrix
            ((2, 2, 1), (3, 1, 2)),  # cannot broadcast
        ]

        for dt, (dm1, dm2) in itertools.product(self.types, dims):
            a = np.ones(dm1, dtype=dt)
            b = np.ones(dm2, dtype=dt)
            assert_raises((RuntimeError, ValueError), self.matmul, a, b)

    def test_shapes(self):
        dims = [
            ((1, 1), (2, 1, 1)),  # broadcast first argument
            ((2, 1, 1), (1, 1)),  # broadcast second argument
            ((2, 1, 1), (2, 1, 1)),  # matrix stack sizes match
        ]

        for dt, (dm1, dm2) in itertools.product(self.types, dims):
            a = np.ones(dm1, dtype=dt)
            b = np.ones(dm2, dtype=dt)
            res = self.matmul(a, b)
            assert_(res.shape == (2, 1, 1))

        # vector vector returns scalars.
        for dt in self.types:
            a = np.ones((2,), dtype=dt)
            b = np.ones((2,), dtype=dt)
            c = self.matmul(a, b)
            assert_(np.array(c).shape == ())

    def test_result_types(self):
        # The result dtype matches the (common) operand dtype.
        mat = np.ones((1, 1))
        vec = np.ones((1,))
        for dt in self.types:
            m = mat.astype(dt)
            v = vec.astype(dt)
            for arg in [(m, v), (v, m), (m, m)]:
                res = self.matmul(*arg)
                assert_(res.dtype == dt)

    @xpassIfTorchDynamo  # (reason="no scalars")
    def test_result_types_2(self):
        # in numpy, vector vector returns scalars
        # we return a 0D array instead

        for dt in self.types:
            v = np.ones((1,)).astype(dt)
            if dt != "O":
                res = self.matmul(v, v)
                assert_(type(res) is np.dtype(dt).type)

    def test_scalar_output(self):
        vec1 = np.array([2])
        vec2 = np.array([3, 4]).reshape(1, -1)
        tgt = np.array([6, 8])
        # types[1:] skips the boolean type, checked separately below.
        for dt in self.types[1:]:
            v1 = vec1.astype(dt)
            v2 = vec2.astype(dt)
            res = self.matmul(v1, v2)
            assert_equal(res, tgt)
            res = self.matmul(v2.T, v1)
            assert_equal(res, tgt)

        # boolean type
        vec = np.array([True, True], dtype="?").reshape(1, -1)
        res = self.matmul(vec[:, 0], vec)
        assert_equal(res, True)

    def test_vector_vector_values(self):
        vec1 = np.array([1, 2])
        vec2 = np.array([3, 4]).reshape(-1, 1)
        tgt1 = np.array([11])
        tgt2 = np.array([[3, 6], [4, 8]])
        for dt in self.types[1:]:
            v1 = vec1.astype(dt)
            v2 = vec2.astype(dt)
            res = self.matmul(v1, v2)
            assert_equal(res, tgt1)
            # no broadcast, we must make v1 into a 2d ndarray
            res = self.matmul(v2, v1.reshape(1, -1))
            assert_equal(res, tgt2)

        # boolean type
        vec = np.array([True, True], dtype="?")
        res = self.matmul(vec, vec)
        assert_equal(res, True)

    def test_vector_matrix_values(self):
        vec = np.array([1, 2])
        mat1 = np.array([[1, 2], [3, 4]])
        mat2 = np.stack([mat1] * 2, axis=0)
        tgt1 = np.array([7, 10])
        tgt2 = np.stack([tgt1] * 2, axis=0)
        for dt in self.types[1:]:
            v = vec.astype(dt)
            m1 = mat1.astype(dt)
            m2 = mat2.astype(dt)
            res = self.matmul(v, m1)
            assert_equal(res, tgt1)
            res = self.matmul(v, m2)
            assert_equal(res, tgt2)

        # boolean type
        vec = np.array([True, False])
        mat1 = np.array([[True, False], [False, True]])
        mat2 = np.stack([mat1] * 2, axis=0)
        tgt1 = np.array([True, False])
        tgt2 = np.stack([tgt1] * 2, axis=0)

        res = self.matmul(vec, mat1)
        assert_equal(res, tgt1)
        res = self.matmul(vec, mat2)
        assert_equal(res, tgt2)

    def test_matrix_vector_values(self):
        vec = np.array([1, 2])
        mat1 = np.array([[1, 2], [3, 4]])
        mat2 = np.stack([mat1] * 2, axis=0)
        tgt1 = np.array([5, 11])
        tgt2 = np.stack([tgt1] * 2, axis=0)
        for dt in self.types[1:]:
            v = vec.astype(dt)
            m1 = mat1.astype(dt)
            m2 = mat2.astype(dt)
            res = self.matmul(m1, v)
            assert_equal(res, tgt1)
            res = self.matmul(m2, v)
            assert_equal(res, tgt2)

        # boolean type: mat1 is the identity, so mat1 @ vec == vec and the
        # stacked mat2 @ vec repeats it along the leading axis.
        vec = np.array([True, False])
        mat1 = np.array([[True, False], [False, True]])
        mat2 = np.stack([mat1] * 2, axis=0)
        tgt1 = np.array([True, False])
        tgt2 = np.stack([tgt1] * 2, axis=0)

        # matrix @ vector (the vector @ matrix order is covered by
        # test_vector_matrix_values)
        res = self.matmul(mat1, vec)
        assert_equal(res, tgt1)
        res = self.matmul(mat2, vec)
        assert_equal(res, tgt2)

    def test_matrix_matrix_values(self):
        mat1 = np.array([[1, 2], [3, 4]])
        mat2 = np.array([[1, 0], [1, 1]])
        mat12 = np.stack([mat1, mat2], axis=0)
        mat21 = np.stack([mat2, mat1], axis=0)
        tgt11 = np.array([[7, 10], [15, 22]])
        tgt12 = np.array([[3, 2], [7, 4]])
        tgt21 = np.array([[1, 2], [4, 6]])
        tgt12_21 = np.stack([tgt12, tgt21], axis=0)
        tgt11_12 = np.stack((tgt11, tgt12), axis=0)
        tgt11_21 = np.stack((tgt11, tgt21), axis=0)
        for dt in self.types[1:]:
            m1 = mat1.astype(dt)
            m2 = mat2.astype(dt)
            m12 = mat12.astype(dt)
            m21 = mat21.astype(dt)

            # matrix @ matrix
            res = self.matmul(m1, m2)
            assert_equal(res, tgt12)
            res = self.matmul(m2, m1)
            assert_equal(res, tgt21)

            # stacked @ matrix
            res = self.matmul(m12, m1)
            assert_equal(res, tgt11_21)

            # matrix @ stacked
            res = self.matmul(m1, m12)
            assert_equal(res, tgt11_12)

            # stacked @ stacked
            res = self.matmul(m12, m21)
            assert_equal(res, tgt12_21)

        # boolean type
        m1 = np.array([[1, 1], [0, 0]], dtype=np.bool_)
        m2 = np.array([[1, 0], [1, 1]], dtype=np.bool_)
        m12 = np.stack([m1, m2], axis=0)
        m21 = np.stack([m2, m1], axis=0)
        tgt11 = m1
        tgt12 = m1
        tgt21 = np.array([[1, 1], [1, 1]], dtype=np.bool_)
        tgt12_21 = np.stack([tgt12, tgt21], axis=0)
        tgt11_12 = np.stack((tgt11, tgt12), axis=0)
        tgt11_21 = np.stack((tgt11, tgt21), axis=0)

        # matrix @ matrix
        res = self.matmul(m1, m2)
        assert_equal(res, tgt12)
        res = self.matmul(m2, m1)
        assert_equal(res, tgt21)

        # stacked @ matrix
        res = self.matmul(m12, m1)
        assert_equal(res, tgt11_21)

        # matrix @ stacked
        res = self.matmul(m1, m12)
        assert_equal(res, tgt11_12)

        # stacked @ stacked
        res = self.matmul(m12, m21)
        assert_equal(res, tgt12_21)
5440
@instantiate_parametrized_tests
class TestMatmul(MatmulCommon, TestCase):
    """Run the MatmulCommon tests through ``np.matmul``, plus ``out=`` tests."""

    def setUp(self):
        # Chain to the base-class setUp so whatever per-test setup it does
        # still runs before the callable under test is installed.
        super().setUp()
        self.matmul = np.matmul

    def test_out_arg(self):
        a = np.ones((5, 2), dtype=float)
        b = np.array([[1, 3], [5, 7]], dtype=float)
        tgt = np.dot(a, b)

        # test as positional argument
        msg = "out positional argument"
        out = np.zeros((5, 2), dtype=float)
        self.matmul(a, b, out)
        assert_array_equal(out, tgt, err_msg=msg)

        # test as keyword argument
        msg = "out keyword argument"
        out = np.zeros((5, 2), dtype=float)
        self.matmul(a, b, out=out)
        assert_array_equal(out, tgt, err_msg=msg)

        # test out with not allowed type cast (safe casting)
        msg = "Cannot cast"
        out = np.zeros((5, 2), dtype=np.int32)
        assert_raises_regex(TypeError, msg, self.matmul, a, b, out=out)

        # test out with type upcast to complex
        out = np.zeros((5, 2), dtype=np.complex128)
        c = self.matmul(a, b, out=out)
        assert_(c is out)
        #      with suppress_warnings() as sup:
        #          sup.filter(np.ComplexWarning, '')
        c = c.astype(tgt.dtype)
        assert_array_equal(c, tgt)

    def test_empty_out(self):
        # Check that the output cannot be broadcast, so that it cannot be
        # size zero when the outer dimensions (iterator size) has size zero.
        arr = np.ones((0, 1, 1))
        out = np.ones((1, 1, 1))
        assert self.matmul(arr, arr).shape == (0, 1, 1)

        with pytest.raises((RuntimeError, ValueError)):
            self.matmul(arr, arr, out=out)

    def test_out_contiguous(self):
        a = np.ones((5, 2), dtype=float)
        b = np.array([[1, 3], [5, 7]], dtype=float)
        v = np.array([1, 3], dtype=float)
        tgt = np.dot(a, b)
        tgt_mv = np.dot(a, v)

        # test out non-contiguous
        out = np.ones((5, 2, 2), dtype=float)
        c = self.matmul(a, b, out=out[..., 0])
        assert_array_equal(c, tgt)
        c = self.matmul(a, v, out=out[:, 0, 0])
        assert_array_equal(c, tgt_mv)
        c = self.matmul(v, a.T, out=out[:, 0, 0])
        assert_array_equal(c, tgt_mv)

        # test out contiguous in only last dim
        out = np.ones((10, 2), dtype=float)
        c = self.matmul(a, b, out=out[::2, :])
        assert_array_equal(c, tgt)

        # test transposes of out, args: (b.T @ a.T).T == a @ b, so writing
        # into out.T must leave `out` itself equal to tgt.
        out = np.ones((5, 2), dtype=float)
        self.matmul(b.T, a.T, out=out.T)
        assert_array_equal(out, tgt)

    @xfailIfTorchDynamo
    def test_out_contiguous_2(self):
        a = np.ones((5, 2), dtype=float)
        b = np.array([[1, 3], [5, 7]], dtype=float)

        # test out non-contiguous: the result must be a view into `out`
        out = np.ones((5, 2, 2), dtype=float)
        c = self.matmul(a, b, out=out[..., 0])
        assert c.tensor._base is out.tensor

    # Class-level operands shared by the test_dot_equivalent parametrization.
    m1 = np.arange(15.0).reshape(5, 3)
    m2 = np.arange(21.0).reshape(3, 7)
    m3 = np.arange(30.0).reshape(5, 6)[:, ::2]  # non-contiguous
    vc = np.arange(10.0)
    vr = np.arange(6.0)
    m0 = np.zeros((3, 0))

    @parametrize(
        "args",
        (
            # matrix-matrix
            subtest((m1, m2), name="mm1"),
            subtest((m2.T, m1.T), name="mm2"),
            subtest((m2.T.copy(), m1.T), name="mm3"),
            subtest((m2.T, m1.T.copy()), name="mm4"),
            # matrix-matrix-transpose, contiguous and non
            subtest((m1, m1.T), name="mmT1"),
            subtest((m1.T, m1), name="mmT2"),
            subtest((m1, m3.T), name="mmT3"),
            subtest((m3, m1.T), name="mmT4"),
            subtest((m3, m3.T), name="mmT5"),
            subtest((m3.T, m3), name="mmT6"),
            # matrix-matrix non-contiguous
            subtest((m3, m2), name="mmN1"),
            subtest((m2.T, m3.T), name="mmN2"),
            subtest((m2.T.copy(), m3.T), name="mmN3"),
            # vector-matrix, matrix-vector, contiguous
            subtest((m1, vr[:3]), name="vm1"),
            subtest((vc[:5], m1), name="vm2"),
            subtest((m1.T, vc[:5]), name="vm3"),
            subtest((vr[:3], m1.T), name="vm4"),
            # vector-matrix, matrix-vector, vector non-contiguous
            subtest((m1, vr[::2]), name="mvN1"),
            subtest((vc[::2], m1), name="mvN2"),
            subtest((m1.T, vc[::2]), name="mvN3"),
            subtest((vr[::2], m1.T), name="mvN4"),
            # vector-matrix, matrix-vector, matrix non-contiguous
            subtest((m3, vr[:3]), name="mvN5"),
            subtest((vc[:5], m3), name="mvN6"),
            subtest((m3.T, vc[:5]), name="mvN7"),
            subtest((vr[:3], m3.T), name="mvN8"),
            # vector-matrix, matrix-vector, both non-contiguous
            subtest((m3, vr[::2]), name="mvN9"),
            subtest((vc[::2], m3), name="mvn10"),
            subtest((m3.T, vc[::2]), name="mv11"),
            subtest((vr[::2], m3.T), name="mv12"),
            # size == 0
            subtest((m0, m0.T), name="s0_1"),
            subtest((m0.T, m0), name="s0_2"),
            subtest((m1, m0), name="s0_3"),
            subtest((m0.T, m1.T), name="s0_4"),
        ),
    )
    def test_dot_equivalent(self, args):
        # For 1-D/2-D operands matmul must agree with dot, and the result must
        # not depend on operand memory layout (compare with copies).
        r1 = np.matmul(*args)
        r2 = np.dot(*args)
        assert_equal(r1, r2)

        r3 = np.matmul(args[0].copy(), args[1].copy())
        assert_equal(r1, r3)

    @skip(reason="object arrays")
    def test_matmul_exception_multiply(self):
        # test that matmul fails if `__mul__` is missing
        class add_not_multiply:
            def __add__(self, other):
                return self

        a = np.full((3, 3), add_not_multiply())
        with assert_raises(TypeError):
            np.matmul(a, a)

    @skip(reason="object arrays")
    def test_matmul_exception_add(self):
        # test that matmul fails if `__add__` is missing
        class multiply_not_add:
            def __mul__(self, other):
                return self

        a = np.full((3, 3), multiply_not_add())
        with assert_raises(TypeError):
            np.matmul(a, a)

    def test_matmul_bool(self):
        # gh-14439
        a = np.array([[1, 0], [1, 1]], dtype=bool)
        assert np.max(a.view(np.uint8)) == 1
        b = np.matmul(a, a)
        # matmul with boolean output should always be 0, 1
        assert np.max(b.view(np.uint8)) == 1

        # rg = np.random.default_rng(np.random.PCG64(43))
        # d = rg.integers(2, size=4*5, dtype=np.int8)
        # d = d.reshape(4, 5) > 0
        np.random.seed(1234)
        d = np.random.randint(2, size=(4, 5)) > 0

        out1 = np.matmul(d, d.reshape(5, 4))
        out2 = np.dot(d, d.reshape(5, 4))
        assert_equal(out1, out2)

        c = np.matmul(np.zeros((2, 0), dtype=bool), np.zeros(0, dtype=bool))
        assert not np.any(c)
5626
5627
class TestMatmulOperator(MatmulCommon, TestCase):
    """Run the MatmulCommon tests through the ``@`` operator."""

    # Uses the module-level ``operator`` import; a builtin function stored
    # as a class attribute is not bound to ``self``, so ``self.matmul(a, b)``
    # calls ``operator.matmul(a, b)``.
    matmul = operator.matmul

    @skip(reason="no __array_priority__")
    def test_array_priority_override(self):
        class A:
            __array_priority__ = 1000

            def __matmul__(self, other):
                return "A"

            def __rmatmul__(self, other):
                return "A"

        a = A()
        b = np.ones(2)
        assert_equal(self.matmul(a, b), "A")
        assert_equal(self.matmul(b, a), "A")

    def test_matmul_raises(self):
        # Scalars are not valid matmul operands.
        assert_raises(
            (RuntimeError, TypeError, ValueError), self.matmul, np.int8(5), np.int8(5)
        )

    @xpassIfTorchDynamo  # (reason="torch supports inplace matmul, and so do we")
    @skipif(
        # Compare (major, minor) numerically; a plain string comparison would
        # misorder versions such as "1.9" vs "1.26".
        tuple(int(part) for part in numpy.__version__.split(".")[:2]) >= (1, 26),
        reason="This is fixed in numpy 1.26",
    )
    def test_matmul_inplace(self):
        # It would be nice to support in-place matmul eventually, but for now
        # we don't have a working implementation, so better just to error out
        # and nudge people to writing "a = a @ b".
        a = np.eye(3)
        b = np.eye(3)
        assert_raises(TypeError, a.__imatmul__, b)

    @xfail  # XXX: what's up with exec under Dynamo
    def test_matmul_inplace_2(self):
        a = np.eye(3)
        b = np.eye(3)

        assert_raises(TypeError, operator.imatmul, a, b)
        assert_raises(TypeError, exec, "a @= b", globals(), locals())

    @xpassIfTorchDynamo  # (reason="matmul_axes")
    def test_matmul_axes(self):
        a = np.arange(3 * 4 * 5).reshape(3, 4, 5)
        c = np.matmul(a, a, axes=[(-2, -1), (-1, -2), (1, 2)])
        assert c.shape == (3, 4, 4)
        d = np.matmul(a, a, axes=[(-2, -1), (-1, -2), (0, 1)])
        assert d.shape == (4, 4, 3)
        e = np.swapaxes(d, 0, 2)
        assert_array_equal(e, c)
        f = np.matmul(a, np.arange(3), axes=[(1, 0), (0), (0)])
        assert f.shape == (4, 5)
5683
5684
class TestInner(TestCase):
    # np.inner sums products over the last axis of both operands.

    def test_inner_scalar_and_vector(self):
        for code in np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "?":
            scalar = np.array(3, dtype=code)[()]
            vector = np.array([1, 2], dtype=code)
            expected = np.array([3, 6], dtype=code)
            # The scalar may appear on either side.
            assert_equal(np.inner(vector, scalar), expected)
            assert_equal(np.inner(scalar, vector), expected)

    def test_vecself(self):
        # Ticket 844: the inner product of an array with itself used to
        # segfault or give a meaningless result.
        zeros = np.zeros(shape=(1, 80), dtype=np.float64)
        assert_almost_equal(np.inner(zeros, zeros), 0, decimal=14)

    def test_inner_product_with_various_contiguities(self):
        # github issue 6532
        for code in np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "?":
            A = np.array([[1, 2], [3, 4]], dtype=code)
            B = np.array([[1, 3], [2, 4]], dtype=code)  # same values as A.T
            C = np.array([1, 1], dtype=code)

            # Matrix-vector products involving a transposed (non-contiguous)
            # matrix and its contiguous equivalent.
            expected_vec = np.array([4, 6], dtype=code)
            for lhs, rhs in [(A.T, C), (C, A.T), (B, C), (C, B)]:
                assert_equal(np.inner(lhs, rhs), expected_vec)

            # Matrix-matrix product.
            assert_equal(np.inner(A, B), np.array([[7, 10], [15, 22]], dtype=code))

            # Self product vs. product with a copy (syrk vs. gemm paths).
            gram = np.array([[5, 11], [11, 25]], dtype=code)
            assert_equal(np.inner(A, A), gram)
            assert_equal(np.inner(A, A.copy()), gram)

    @skip(reason="[::-1] not supported")
    def test_inner_product_reversed_view(self):
        for code in np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "?":
            # One operand is a reversed view aliasing the other.
            fwd = np.arange(5).astype(code)
            rev = fwd[::-1]
            # 4*0 + 3*1 + 2*2 + 1*3 + 0*4 == 10
            expected = np.array(10, dtype=code).item()
            assert_equal(np.inner(rev, fwd), expected)

    def test_3d_tensor(self):
        for code in np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "?":
            lhs = np.arange(24).reshape(2, 3, 4).astype(code)
            rhs = np.arange(24, 48).reshape(2, 3, 4).astype(code)
            expected = np.array(
                [
                    [
                        [[158, 182, 206], [230, 254, 278]],
                        [[566, 654, 742], [830, 918, 1006]],
                        [[974, 1126, 1278], [1430, 1582, 1734]],
                    ],
                    [
                        [[1382, 1598, 1814], [2030, 2246, 2462]],
                        [[1790, 2070, 2350], [2630, 2910, 3190]],
                        [[2198, 2542, 2886], [3230, 3574, 3918]],
                    ],
                ]
            ).astype(code)
            assert_equal(np.inner(lhs, rhs), expected)
            # Swapping the operands swaps the two leading axis pairs.
            assert_equal(np.inner(rhs, lhs).transpose(2, 3, 0, 1), expected)
5751
5752
@instantiate_parametrized_tests
class TestChoose(TestCase):
    def setUp(self):
        # 1-D and 2-D choice arrays filled with 2 (x) and 3 (y), plus an
        # index that picks [x, x, y].
        ones_1d = np.ones((3,), dtype=int)
        ones_2d = np.ones((2, 3), dtype=int)
        self.x = 2 * ones_1d
        self.y = 3 * ones_1d
        self.x2 = 2 * ones_2d
        self.y2 = 3 * ones_2d
        self.ind = [0, 0, 1]

    def test_basic(self):
        picked = np.choose(self.ind, (self.x, self.y))
        assert_equal(picked, [2, 2, 3])

    def test_broadcast1(self):
        # The 1-D index broadcasts against 2-D choices.
        picked = np.choose(self.ind, (self.x2, self.y2))
        assert_equal(picked, [[2, 2, 3], [2, 2, 3]])

    def test_broadcast2(self):
        # Choices of different ranks broadcast against each other.
        picked = np.choose(self.ind, (self.x, self.y2))
        assert_equal(picked, [[2, 2, 3], [2, 2, 3]])

    # XXX: revisit xfails when NEP 50 lands in numpy
    @skip(reason="XXX: revisit xfails when NEP 50 lands in numpy")
    @parametrize(
        "ops",
        [
            (1000, np.array([1], dtype=np.uint8)),
            (-1, np.array([1], dtype=np.uint8)),
            (1.0, np.float32(3)),
            (1.0, np.array([3], dtype=np.float32)),
        ],
    )
    def test_output_dtype(self, ops):
        # The result dtype follows the usual type-promotion rules.
        assert np.choose([0], ops).dtype == np.result_type(*ops)

    def test_docstring_1(self):
        # examples from the docstring,
        # https://numpy.org/doc/1.23/reference/generated/numpy.choose.html
        choices = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
        assert_equal(np.choose([2, 3, 1, 0], choices), [20, 31, 12, 3])

    def test_docstring_2(self):
        selector = [[1, 0, 1], [0, 1, 0], [1, 0, 1]]
        picked = np.choose(selector, [-10, 10])
        assert_equal(picked, [[10, -10, 10], [-10, 10, -10], [10, -10, 10]])

    def test_docstring_3(self):
        # Result is 2x3x5 with res[0, :, :] = c1 and res[1, :, :] = c2.
        selector = np.array([0, 1]).reshape((2, 1, 1))
        c1 = np.array([1, 2, 3]).reshape((1, 3, 1))
        c2 = np.array([-1, -2, -3, -4, -5]).reshape((1, 1, 5))
        picked = np.choose(selector, (c1, c2))
        expected = np.array(
            [
                [[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]],
                [[-1, -2, -3, -4, -5], [-1, -2, -3, -4, -5], [-1, -2, -3, -4, -5]],
            ]
        )
        assert_equal(picked, expected)
5815
5816
class TestRepeat(TestCase):
    def setUp(self):
        # 1-D source [1..6] and its 2x3 reshaping.
        self.m = np.array([1, 2, 3, 4, 5, 6])
        self.m_rect = self.m.reshape((2, 3))

    def test_basic(self):
        # One repeat count per element.
        counts = [1, 3, 2, 1, 1, 2]
        expected = [1, 2, 2, 2, 3, 3, 4, 5, 6, 6]
        assert_equal(np.repeat(self.m, counts), expected)

    def test_broadcast1(self):
        # A scalar count applies to every element.
        expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
        assert_equal(np.repeat(self.m, 2), expected)

    def test_axis_spec(self):
        # One count per row (axis=0) or per column (axis=1).
        by_row = np.repeat(self.m_rect, [2, 1], axis=0)
        assert_equal(by_row, [[1, 2, 3], [1, 2, 3], [4, 5, 6]])

        by_col = np.repeat(self.m_rect, [1, 3, 2], axis=1)
        assert_equal(by_col, [[1, 2, 2, 2, 3, 3], [4, 5, 5, 5, 6, 6]])

    def test_broadcast2(self):
        # A scalar count along an explicit axis.
        by_row = np.repeat(self.m_rect, 2, axis=0)
        assert_equal(by_row, [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])

        by_col = np.repeat(self.m_rect, 2, axis=1)
        assert_equal(by_col, [[1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6]])
5843
5844
# TODO: test for multidimensional
# Mode name -> integer code (presumably neighborhood-iterator modes).
NEIGH_MODE = dict(zero=0, one=1, constant=2, circular=3, mirror=4)
5847
5848
@xpassIfTorchDynamo  # (reason="TODO")
class TestWarnings(TestCase):
    def test_complex_warning(self):
        # Assigning complex values into a real array emits ComplexWarning.
        # With that warning escalated to an error, the assignment must raise
        # and leave the target unchanged.
        target = np.array([1, 2])
        source = np.array([1 - 2j, 1 + 2j])

        with warnings.catch_warnings():
            warnings.simplefilter("error", np.ComplexWarning)
            assert_raises(np.ComplexWarning, target.__setitem__, slice(None), source)
            assert_equal(target, [1, 2])
5859
5860
class TestMinScalarType(TestCase):
    def test_usigned_shortshort(self):
        # 2**8 - 1 == 255 is the largest value that fits in uint8.
        assert_equal(np.dtype("uint8"), np.min_scalar_type(2**8 - 1))

    # The three tests below pin down what NumPy itself returns.
    def test_complex(self):
        assert np.min_scalar_type(0 + 0j) == np.dtype("complex64")

    def test_float(self):
        assert np.min_scalar_type(0.1) == np.dtype("float16")

    def test_nonscalar(self):
        # Non-scalar input is not shrunk: its regular dtype comes back.
        assert np.min_scalar_type([0, 1, 2]) == np.dtype("int64")
5879
5880
5881from numpy.core._internal import _dtype_from_pep3118
5882
5883
@skip(reason="dont worry about buffer protocol")
class TestPEP3118Dtype(TestCase):
    """Parse PEP 3118 buffer format strings into dtypes.

    Each test feeds a format ``spec`` to NumPy's private
    ``_dtype_from_pep3118`` and compares the result with ``np.dtype(wanted)``.
    The whole class is skipped in this port.
    """

    def _check(self, spec, wanted):
        # Parse `spec` and assert it equals the dtype built from `wanted`.
        dt = np.dtype(wanted)
        actual = _dtype_from_pep3118(spec)
        assert_equal(actual, dt, err_msg=f"spec {spec!r} != dtype {wanted!r}")

    def test_native_padding(self):
        # "b" (int8), then j explicit pad bytes ("x"), then an "i".
        # With "@" (native alignment) the "i" offset 1 + j is rounded up to a
        # multiple of its alignment; with "=" no alignment is applied and the
        # "i" sits directly after the padding at offset 1 + j.
        align = np.dtype("i").alignment
        for j in range(8):
            if j == 0:
                s = "bi"
            else:
                s = "b%dxi" % j
            self._check(
                "@" + s, {"f0": ("i1", 0), "f1": ("i", align * (1 + j // align))}
            )
            self._check("=" + s, {"f0": ("i1", 0), "f1": ("i", 1 + j)})

    def test_native_padding_2(self):
        # Native padding should work also for structs and sub-arrays
        # "x3T{xi}" is one pad byte followed by a (3,) sub-array of structs
        # {pad, i}. With default (native) alignment, both the inner "i" and the
        # sub-array land at offset 4; with "^" (no alignment) both are at 1.
        self._check("x3T{xi}", {"f0": (({"f0": ("i", 4)}, (3,)), 4)})
        self._check("^x3T{xi}", {"f0": (({"f0": ("i", 1)}, (3,)), 1)})

    def test_trailing_padding(self):
        # Trailing padding should be included, *and*, the item size
        # should match the alignment if in aligned mode
        align = np.dtype("i").alignment
        size = np.dtype("i").itemsize

        # Round n (n >= 1) up to the next multiple of the int alignment.
        def aligned(n):
            return align * (1 + (n - 1) // align)

        base = dict(formats=["i"], names=["f0"])

        # Aligned (default) mode: itemsize = int + padding, rounded up.
        self._check("ix", dict(itemsize=aligned(size + 1), **base))
        self._check("ixx", dict(itemsize=aligned(size + 2), **base))
        self._check("ixxx", dict(itemsize=aligned(size + 3), **base))
        self._check("ixxxx", dict(itemsize=aligned(size + 4), **base))
        self._check("i7x", dict(itemsize=aligned(size + 7), **base))

        # Unaligned ("^") mode: itemsize = int + padding exactly.
        self._check("^ix", dict(itemsize=size + 1, **base))
        self._check("^ixx", dict(itemsize=size + 2, **base))
        self._check("^ixxx", dict(itemsize=size + 3, **base))
        self._check("^ixxxx", dict(itemsize=size + 4, **base))
        self._check("^i7x", dict(itemsize=size + 7, **base))

    def test_native_padding_3(self):
        # Nested structs. The specs spell the padding out explicitly ("xxx")
        # and use "=" to switch native alignment off; the matching dtypes use
        # align=True on the outer struct (first case) or only on the inner
        # struct (second case).
        dt = np.dtype(
            [("a", "b"), ("b", "i"), ("sub", np.dtype("b,i")), ("c", "i")], align=True
        )
        self._check("T{b:a:xxxi:b:T{b:f0:=i:f1:}:sub:xxxi:c:}", dt)

        dt = np.dtype(
            [
                ("a", "b"),
                ("b", "i"),
                ("c", "b"),
                ("d", "b"),
                ("e", "b"),
                ("sub", np.dtype("b,i", align=True)),
            ]
        )
        self._check("T{b:a:=i:b:b:c:b:d:b:e:T{b:f0:xxxi:f1:}:sub:}", dt)

    def test_padding_with_array_inside_struct(self):
        # "3b:c:" is a (3,) int8 sub-array; one pad byte ("x") follows it so
        # that the trailing "i" is aligned.
        dt = np.dtype(
            [("a", "b"), ("b", "i"), ("c", "b", (3,)), ("d", "i")], align=True
        )
        self._check("T{b:a:xxxi:b:3b:c:xi:d:}", dt)

    def test_byteorder_inside_struct(self):
        # A byte-order/alignment character used inside a struct ("^" here)
        # still applies after the struct closes, rather than reverting to the
        # outer "@". This is visible because f1 lands at offset 5 (4-byte
        # struct + 1 pad byte) instead of at a natively aligned offset (8,
        # assuming a 4-byte int alignment).
        self._check("@T{^i}xi", {"f0": ({"f0": ("i", 0)}, 0), "f1": ("i", 5)})

    def test_intra_padding(self):
        # Natively aligned sub-arrays may require some internal padding
        align = np.dtype("i").alignment
        size = np.dtype("i").itemsize

        # Round n (n >= 1) up to the next multiple of the int alignment.
        def aligned(n):
            return align * (1 + (n - 1) // align)

        self._check(
            "(3)T{ix}",
            (
                dict(
                    names=["f0"], formats=["i"], offsets=[0], itemsize=aligned(size + 1)
                ),
                (3,),
            ),
        )

    def test_char_vs_string(self):
        # "c" is a single char (dtype "c"); "4c" is a (4,) array of 1-byte
        # strings, while "4s" is one 4-byte string.
        dt = np.dtype("c")
        self._check("c", dt)

        dt = np.dtype([("f0", "S1", (4,)), ("f1", "S4")])
        self._check("4c4s", dt)

    def test_field_order(self):
        # gh-9053 - previously, we relied on dictionary key order
        # Fields must come out in spec order whatever their names;
        # "(0)I" is a zero-length sub-array.
        self._check("(0)I:a:f:b:", [("a", "I", (0,)), ("b", "f")])
        self._check("(0)I:b:f:a:", [("b", "I", (0,)), ("a", "f")])

    def test_unnamed_fields(self):
        # Unnamed fields get default names f0, f1, ...; when an explicit name
        # takes a default name ("f0" below), the unnamed field gets another
        # one ("f1"). A single unnamed item is a plain dtype, not a struct.
        self._check("ii", [("f0", "i"), ("f1", "i")])
        self._check("ii:f0:", [("f1", "i"), ("f0", "i")])

        self._check("i", "i")
        self._check("i:f0:", [("f0", "i")])
5996
5997
# NOTE: reasons this class is marked xpassIfTorchDynamo below:
# 1. TODO: torch._numpy does not handle/model _CopyMode
# 2. order= keyword not supported (probably won't be)
# 3. Under TEST_WITH_TORCHDYNAMO many of these make it through due
#    to a graph break leaving the _CopyMode to only be handled by numpy.
@skipif(
    # NumpyVersion compares numerically; a plain string comparison would
    # treat e.g. "1.9.3" as newer than "1.23".
    numpy.lib.NumpyVersion(numpy.__version__) < "1.23.0",
    reason="CopyMode is new in NumPy 1.22",
)
@xpassIfTorchDynamo
@instantiate_parametrized_tests
class TestArrayCreationCopyArgument(TestCase):
    """Check when ``np.array(..., copy=...)`` returns a copy versus a view."""

    class RaiseOnBool:
        # A `copy=` value whose truthiness cannot be evaluated; test_scalars
        # expects np.array to surface the resulting ValueError.
        def __bool__(self):
            raise ValueError

    # Values meaning "always copy" (true_vals) and "copy only if needed"
    # (false_vals). np._CopyMode.ALWAYS / np._CopyMode.IF_NEEDED would be the
    # natural enum entries, but torch._numpy does not model _CopyMode (see
    # the NOTE above), so plain ints stand in for them.
    true_vals = [True, 1, np.True_]
    false_vals = [False, 0, np.False_]

    def test_scalars(self):
        # Test both numpy and python scalars
        for dtype in np.typecodes["All"]:
            arr = np.zeros((), dtype=dtype)
            scalar = arr[()]
            pyscalar = arr.item(0)

            # Test never-copy raises error:
            assert_raises(ValueError, np.array, scalar, copy=np._CopyMode.NEVER)
            assert_raises(ValueError, np.array, pyscalar, copy=np._CopyMode.NEVER)
            assert_raises(ValueError, np.array, pyscalar, copy=self.RaiseOnBool())
            # Passing an explicit dtype must not bypass the never-copy check.
            with pytest.raises(ValueError):
                np.array(pyscalar, dtype=np.int64, copy=np._CopyMode.NEVER)

    @xfail  # TODO: handle `_CopyMode` properly in torch._numpy
    def test_compatible_cast(self):
        # Some types are compatible even though they are different, no
        # copy is necessary for them. This is mostly true for some integers
        def int_types(byteswap=False):
            # Yield every signed and unsigned integer dtype, each optionally
            # followed by its byte-swapped counterpart.
            codes = np.typecodes["Integer"] + np.typecodes["UnsignedInteger"]
            for code in codes:
                yield np.dtype(code)
                if byteswap:
                    yield np.dtype(code).newbyteorder()

        for int1 in int_types():
            for int2 in int_types(True):
                arr = np.arange(10, dtype=int1)

                for copy in self.true_vals:
                    res = np.array(arr, copy=copy, dtype=int2)
                    assert res is not arr and res.flags.owndata
                    assert_array_equal(res, arr)

                if int1 == int2:
                    # Casting is not necessary, base check is sufficient here
                    for copy in self.false_vals:
                        res = np.array(arr, copy=copy, dtype=int2)
                        assert res is arr or res.base is arr

                    res = np.array(arr, copy=np._CopyMode.NEVER, dtype=int2)
                    assert res is arr or res.base is arr

                else:
                    # Casting is necessary, assert copy works:
                    for copy in self.false_vals:
                        res = np.array(arr, copy=copy, dtype=int2)
                        assert res is not arr and res.flags.owndata
                        assert_array_equal(res, arr)

                    assert_raises(
                        ValueError, np.array, arr, copy=np._CopyMode.NEVER, dtype=int2
                    )
                    assert_raises(ValueError, np.array, arr, copy=None, dtype=int2)

    def test_buffer_interface(self):
        # Buffer interface gives direct memory access (no copy)
        arr = np.arange(10)
        view = memoryview(arr)

        # Checking bases is a bit tricky since numpy creates another
        # memoryview, so use may_share_memory.
        for copy in self.true_vals:
            res = np.array(view, copy=copy)
            assert not np.may_share_memory(arr, res)
        for copy in self.false_vals:
            res = np.array(view, copy=copy)
            assert np.may_share_memory(arr, res)
        res = np.array(view, copy=np._CopyMode.NEVER)
        assert np.may_share_memory(arr, res)

    def test_array_interfaces(self):
        # Array interface gives direct memory access (much like a memoryview)
        base_arr = np.arange(10)

        class ArrayLike:
            __array_interface__ = base_arr.__array_interface__

        arr = ArrayLike()

        # Copying requests yield an array with no base; non-copying requests
        # yield a view whose base is the ArrayLike itself.
        for copy, val in [
            (True, None),
            (np._CopyMode.ALWAYS, None),
            (False, arr),
            (np._CopyMode.IF_NEEDED, arr),
            (np._CopyMode.NEVER, arr),
        ]:
            res = np.array(arr, copy=copy)
            assert res.base is val

    def test___array__(self):
        base_arr = np.arange(10)

        class ArrayLike:
            def __array__(self):
                # __array__ should return a copy, numpy cannot know this
                # however.
                return base_arr

        arr = ArrayLike()

        for copy in self.true_vals:
            res = np.array(arr, copy=copy)
            assert_array_equal(res, base_arr)
            # An additional copy is currently forced by numpy in this case,
            # you could argue, numpy does not trust the ArrayLike. This
            # may be open for change:
            assert res is not base_arr

        for copy in self.false_vals:
            # Use the loop value (False, 0, np.False_) so that every spelling
            # of "no copy" is exercised, not only the literal False.
            res = np.array(arr, copy=copy)
            assert_array_equal(res, base_arr)
            assert res is base_arr  # numpy trusts the ArrayLike

        with pytest.raises(ValueError):
            np.array(arr, copy=np._CopyMode.NEVER)

    @parametrize("arr", [np.ones(()), np.arange(81).reshape((9, 9))])
    @parametrize("order1", ["C", "F", None])
    @parametrize("order2", ["C", "F", "A", "K"])
    def test_order_mismatch(self, arr, order1, order2):
        # The order is the main (python side) reason that can cause
        # a never-copy to fail.
        # Prepare C-order, F-order and non-contiguous arrays:
        arr = arr.copy(order1)
        if order1 == "C":
            assert arr.flags.c_contiguous
        elif order1 == "F":
            assert arr.flags.f_contiguous
        elif arr.ndim != 0:
            # Make array non-contiguous
            arr = arr[::2, ::2]
            assert not arr.flags.forc

        # Whether a copy is necessary depends on the order of arr:
        if order2 == "C":
            no_copy_necessary = arr.flags.c_contiguous
        elif order2 == "F":
            no_copy_necessary = arr.flags.f_contiguous
        else:
            # Keeporder and Anyorder are OK with non-contiguous output.
            # This is not consistent with the `astype` behaviour which
            # enforces contiguity for "A". It is probably historic from when
            # "K" did not exist.
            no_copy_necessary = True

        # Test it for both the array and a memoryview
        for view in [arr, memoryview(arr)]:
            for copy in self.true_vals:
                res = np.array(view, copy=copy, order=order2)
                assert res is not arr and res.flags.owndata
                assert_array_equal(arr, res)

            if no_copy_necessary:
                for copy in self.false_vals:
                    res = np.array(view, copy=copy, order=order2)
                    # res.base.obj refers to the memoryview
                    if not IS_PYPY:
                        assert res is arr or res.base.obj is arr

                res = np.array(view, copy=np._CopyMode.NEVER, order=order2)
                if not IS_PYPY:
                    assert res is arr or res.base.obj is arr
            else:
                # A copy is unavoidable: "copy if needed" must still succeed,
                # while never-copy must refuse. Both use `view` so that the
                # memoryview input is exercised too.
                for copy in self.false_vals:
                    res = np.array(view, copy=copy, order=order2)
                    assert_array_equal(arr, res)
                assert_raises(
                    ValueError, np.array, view, copy=np._CopyMode.NEVER, order=order2
                )
                assert_raises(ValueError, np.array, view, copy=None, order=order2)

    def test_striding_not_ok(self):
        # arr is C-contiguous and arr.T is F-contiguous; asking either for the
        # other memory order needs a copy, which NEVER forbids.
        arr = np.array([[1, 2, 4], [3, 4, 5]])
        assert_raises(ValueError, np.array, arr.T, copy=np._CopyMode.NEVER, order="C")
        assert_raises(
            ValueError,
            np.array,
            arr.T,
            copy=np._CopyMode.NEVER,
            order="C",
            dtype=np.int64,
        )
        assert_raises(ValueError, np.array, arr, copy=np._CopyMode.NEVER, order="F")
        assert_raises(
            ValueError,
            np.array,
            arr,
            copy=np._CopyMode.NEVER,
            order="F",
            dtype=np.int64,
        )
6209
6210
class TestArrayAttributeDeletion(TestCase):
    # Deleting any of these ndarray / flags attributes must raise
    # AttributeError, whether or not the attribute is assignable.

    def test_multiarray_writable_attributes_deletion(self):
        # ticket #2046: deleting these used to segfault; it must raise
        # AttributeError instead.
        arr = np.ones(2)
        names = ("shape", "strides", "data", "dtype", "real", "imag", "flat")
        with suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Assigning the 'data' attribute")
            for name in names:
                assert_raises(AttributeError, delattr, arr, name)

    def test_multiarray_not_writable_attributes_deletion(self):
        arr = np.ones(2)
        names = (
            "ndim",
            "flags",
            "itemsize",
            "size",
            "nbytes",
            "base",
            "ctypes",
            "T",
            "__array_interface__",
            "__array_struct__",
            "__array_priority__",
            "__array_finalize__",
        )
        for name in names:
            assert_raises(AttributeError, delattr, arr, name)

    def test_multiarray_flags_writable_attribute_deletion(self):
        flags = np.ones(2).flags
        names = ("writebackifcopy", "updateifcopy", "aligned", "writeable")
        for name in names:
            assert_raises(AttributeError, delattr, flags, name)

    def test_multiarray_flags_not_writable_attribute_deletion(self):
        flags = np.ones(2).flags
        names = (
            "contiguous",
            "c_contiguous",
            "f_contiguous",
            "fortran",
            "owndata",
            "fnc",
            "forc",
            "behaved",
            "carray",
            "farray",
            "num",
        )
        for name in names:
            assert_raises(AttributeError, delattr, flags, name)
6263
6264
# Pass an explicit reason: a bare `@skip` on a class expands to
# unittest.skipIf(True, cls), which rebinds the class name to a plain
# decorator function so the tests silently disappear instead of being
# reported as skipped.
@skip(reason="not supported, too brittle, too annoying")
@instantiate_parametrized_tests
class TestArrayInterface(TestCase):
    class Foo:
        # Float-convertible object exposing a mutable __array_interface__
        # dict (`iface`), so tests can inject arbitrary interface entries.
        def __init__(self, value):
            self.value = value
            self.iface = {"typestr": "f8"}

        def __float__(self):
            return float(self.value)

        @property
        def __array_interface__(self):
            return self.iface

    # Shared instance: the parametrize table below refers to it directly, and
    # each test resets its interface dict before use.
    f = Foo(0.5)

    @parametrize(
        "val, iface, expected",
        [
            (f, {}, 0.5),
            ([f], {}, [0.5]),
            ([f, f], {}, [0.5, 0.5]),
            (f, {"shape": ()}, 0.5),
            (f, {"shape": None}, TypeError),
            (f, {"shape": (1, 1)}, [[0.5]]),
            (f, {"shape": (2,)}, ValueError),
            (f, {"strides": ()}, 0.5),
            (f, {"strides": (2,)}, ValueError),
            (f, {"strides": 16}, TypeError),
        ],
    )
    def test_scalar_interface(self, val, iface, expected):
        # Test scalar coercion within the array interface. `expected` is
        # either the resulting value or the exception type np.array raises.
        self.f.iface = {"typestr": "f8"}
        self.f.iface.update(iface)
        if HAS_REFCOUNT:
            pre_cnt = sys.getrefcount(np.dtype("f8"))
        if isinstance(expected, type):
            assert_raises(expected, np.array, val)
        else:
            result = np.array(val)
            assert_equal(result, expected)
            assert result.dtype == "f8"
            del result
        if HAS_REFCOUNT:
            # Coercion must not leak references to the f8 dtype.
            post_cnt = sys.getrefcount(np.dtype("f8"))
            assert_equal(pre_cnt, post_cnt)
6313
6314
class TestDelMisc(TestCase):
    @xpassIfTorchDynamo  # (reason="TODO")
    def test_flat_element_deletion(self):
        # Deleting elements of a flat iterator may succeed or raise TypeError;
        # any other exception fails the test.
        it = np.ones(3).flat
        try:
            del it[1]
            del it[1:2]
        except TypeError:
            pass
        except Exception as exc:
            # Chain the unexpected exception so the failure shows what was
            # actually raised.
            raise AssertionError(
                f"unexpected {type(exc).__name__} when deleting flat elements"
            ) from exc
6326
6327
class TestConversion(TestCase):
    # Comparisons between arrays and Python scalars, and conversion of
    # arrays to Python bool / int.

    def test_array_scalar_relational_operation(self):
        # All integer
        for dt1 in np.typecodes["AllInteger"]:
            assert_(1 > np.array(0, dtype=dt1), f"type {dt1} failed")
            assert_(not 1 < np.array(0, dtype=dt1), f"type {dt1} failed")

            for dt2 in np.typecodes["AllInteger"]:
                assert_(
                    np.array(1, dtype=dt1) > np.array(0, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1) < np.array(0, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )

        # Unsigned integers
        for dt1 in "B":
            assert_(-1 < np.array(1, dtype=dt1), f"type {dt1} failed")
            assert_(not -1 > np.array(1, dtype=dt1), f"type {dt1} failed")
            assert_(-1 != np.array(1, dtype=dt1), f"type {dt1} failed")

            # Unsigned vs signed
            for dt2 in "bhil":
                assert_(
                    np.array(1, dtype=dt1) > np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1) < np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    np.array(1, dtype=dt1) != np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )

        # Signed integers and floats
        for dt1 in "bhl" + np.typecodes["Float"]:
            assert_(1 > np.array(-1, dtype=dt1), f"type {dt1} failed")
            assert_(not 1 < np.array(-1, dtype=dt1), f"type {dt1} failed")
            assert_(-1 == np.array(-1, dtype=dt1), f"type {dt1} failed")

            for dt2 in "bhl" + np.typecodes["Float"]:
                assert_(
                    np.array(1, dtype=dt1) > np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1) < np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    np.array(-1, dtype=dt1) == np.array(-1, dtype=dt2),
                    f"type {dt1} and {dt2} failed",
                )

    @skip(reason="object arrays")
    def test_to_bool_scalar(self):
        # Single-element arrays convert to bool; larger arrays raise.
        assert_equal(bool(np.array([False])), False)
        assert_equal(bool(np.array([True])), True)
        assert_equal(bool(np.array([[42]])), True)
        assert_raises(ValueError, bool, np.array([1, 2]))

        class NotConvertible:
            def __bool__(self):
                raise NotImplementedError

        # Errors raised by an element's __bool__ must propagate.
        assert_raises(NotImplementedError, bool, np.array(NotConvertible()))
        assert_raises(NotImplementedError, bool, np.array([NotConvertible()]))
        if IS_PYSTON:
            raise SkipTest("Pyston disables recursion checking")

        # An object array containing itself must hit the recursion limit
        # (this previously overflowed the stack).
        self_containing = np.array([None])
        self_containing[0] = self_containing
        try:
            assert_raises(RecursionError, bool, self_containing)
        finally:
            # Always break the circular reference, even if the assertion
            # above fails.
            self_containing[0] = None

    def test_to_int_scalar(self):
        # gh-9972 means that these aren't always the same
        int_funcs = (int, lambda x: x.__int__())
        for int_func in int_funcs:
            assert_equal(int_func(np.array(0)), 0)
            assert_equal(int_func(np.array([1])), 1)
            assert_equal(int_func(np.array([[42]])), 42)
            assert_raises((ValueError, TypeError), int_func, np.array([1, 2]))

    @skip(reason="object arrays")
    def test_to_int_scalar_2(self):
        int_funcs = (int, lambda x: x.__int__())
        for int_func in int_funcs:
            # gh-9972
            assert_equal(4, int_func(np.array("4")))
            assert_equal(5, int_func(np.bytes_(b"5")))
            assert_equal(6, int_func(np.str_("6")))

            # The delegation of int() to __trunc__ was deprecated in
            # Python 3.11, so only check it on older versions.
            if sys.version_info < (3, 11):

                class HasTrunc:
                    def __trunc__(self):
                        return 3

                assert_equal(3, int_func(np.array(HasTrunc())))
                assert_equal(3, int_func(np.array([HasTrunc()])))

            class NotConvertible:
                def __int__(self):
                    raise NotImplementedError

            # Errors raised by an element's __int__ must propagate.
            assert_raises(NotImplementedError, int_func, np.array(NotConvertible()))
            assert_raises(NotImplementedError, int_func, np.array([NotConvertible()]))
6447
6448
class TestWhere(TestCase):
    """Tests for the three-argument form of ``np.where`` (plus ``nonzero``)."""

    def test_basic(self):
        """Scalar and array operands, contiguous and strided, for many dtypes."""
        dts = [bool, np.int16, np.int32, np.int64, np.double, np.complex128]
        for dt in dts:
            c = np.ones(53, dtype=bool)
            assert_equal(np.where(c, dt(0), dt(1)), dt(0))
            assert_equal(np.where(~c, dt(0), dt(1)), dt(1))
            assert_equal(np.where(True, dt(0), dt(1)), dt(0))
            assert_equal(np.where(False, dt(0), dt(1)), dt(1))
            d = np.ones_like(c).astype(dt)
            e = np.zeros_like(d)
            # ``astype`` copies, so writing into ``r`` leaves ``d`` intact.
            r = d.astype(dt)
            # Flip one condition entry so the expected result mixes operands.
            c[7] = False
            r[7] = e[7]
            assert_equal(np.where(c, e, e), e)
            assert_equal(np.where(c, d, e), r)
            assert_equal(np.where(c, d, e[0]), r)
            assert_equal(np.where(c, d[0], e), r)
            assert_equal(np.where(c[::2], d[::2], e[::2]), r[::2])
            assert_equal(np.where(c[1::2], d[1::2], e[1::2]), r[1::2])
            assert_equal(np.where(c[::3], d[::3], e[::3]), r[::3])
            assert_equal(np.where(c[1::3], d[1::3], e[1::3]), r[1::3])
        # NOTE(review): the negative-stride variants below are disabled,
        # presumably because negative-step slicing is not supported by the
        # torch backend -- confirm. If re-enabled they belong inside the loop.
        #  assert_equal(np.where(c[::-2], d[::-2], e[::-2]), r[::-2])
        #  assert_equal(np.where(c[::-3], d[::-3], e[::-3]), r[::-3])
        #  assert_equal(np.where(c[1::-3], d[1::-3], e[1::-3]), r[1::-3])

    def test_exotic(self):
        """A zero-sized condition broadcasts with a scalar to a (0, 3) result."""
        # zero sized
        m = np.array([], dtype=bool).reshape(0, 3)
        b = np.array([], dtype=np.float64).reshape(0, 3)
        assert_array_equal(np.where(m, 0, b), np.array([]).reshape(0, 3))

    @skip(reason="object arrays")
    def test_exotic_2(self):
        """Object-array operands, and result dtype with non-finite scalars."""
        # object cast
        d = np.array(
            [
                -1.34,
                -0.16,
                -0.54,
                -0.31,
                -0.08,
                -0.95,
                0.000,
                0.313,
                0.547,
                -0.18,
                0.876,
                0.236,
                1.969,
                0.310,
                0.699,
                1.013,
                1.267,
                0.229,
                -1.39,
                0.487,
            ]
        )
        nan = float("NaN")
        e = np.array(
            [
                "5z",
                "0l",
                nan,
                "Wz",
                nan,
                nan,
                "Xq",
                "cs",
                nan,
                nan,
                "QN",
                nan,
                nan,
                "Fd",
                nan,
                nan,
                "kp",
                nan,
                "36",
                "i1",
            ],
            dtype=object,
        )
        # ``m`` is True exactly where ``e`` holds NaN.
        m = np.array(
            [0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0], dtype=bool
        )

        # NOTE(review): ``e[:]`` is a basic slice, i.e. a *view* of ``e`` in
        # NumPy. The assignments through ``r`` below therefore also write
        # into ``e``, and the later checks run against that mutated ``e``.
        # Switching to ``e.copy()`` would likely make the NaN entries take
        # part in the object comparisons (NaN != NaN) -- verify before
        # changing.
        r = e[:]
        r[np.where(m)] = d[np.where(m)]
        assert_array_equal(np.where(m, d, e), r)

        r = e[:]
        r[np.where(~m)] = d[np.where(~m)]
        assert_array_equal(np.where(m, e, d), r)

        assert_array_equal(np.where(m, e, e), e)

        # minimal dtype result with NaN scalar (e.g required by pandas)
        d = np.array([1.0, 2.0], dtype=np.float32)
        e = float("NaN")
        assert_equal(np.where(True, d, e).dtype, np.float32)
        e = float("Infinity")
        assert_equal(np.where(True, d, e).dtype, np.float32)
        e = float("-Infinity")
        assert_equal(np.where(True, d, e).dtype, np.float32)
        # also check upcast
        e = 1e150
        assert_equal(np.where(True, d, e).dtype, np.float64)

    def test_ndim(self):
        """Condition broadcasting against 2-D operands along either axis."""
        c = [True, False]
        a = np.zeros((2, 25))
        b = np.ones((2, 25))
        # Condition of shape (2, 1): selects whole rows.
        r = np.where(np.array(c)[:, np.newaxis], a, b)
        assert_array_equal(r[0], a[0])
        assert_array_equal(r[1], b[0])

        # Condition of shape (2,) against (25, 2): selects whole columns.
        a = a.T
        b = b.T
        r = np.where(c, a, b)
        assert_array_equal(r[:, 0], a[:, 0])
        assert_array_equal(r[:, 1], b[:, 0])

    def test_dtype_mix(self):
        """Mixed operand dtypes, and non-bool conditions used by truthiness."""
        c = np.array(
            [
                False,
                True,
                False,
                False,
                False,
                False,
                True,
                False,
                False,
                False,
                True,
                False,
            ]
        )
        a = np.uint8(1)
        b = np.array(
            [5.0, 0.0, 3.0, 2.0, -1.0, -4.0, 0.0, -10.0, 10.0, 1.0, 0.0, 3.0],
            dtype=np.float64,
        )
        r = np.array(
            [5.0, 1.0, 3.0, 2.0, -1.0, -4.0, 1.0, -10.0, 10.0, 1.0, 1.0, 3.0],
            dtype=np.float64,
        )
        assert_equal(np.where(c, a, b), r)

        a = a.astype(np.float32)
        b = b.astype(np.int64)
        assert_equal(np.where(c, a, b), r)

        # non bool mask
        c = c.astype(int)
        c[c != 0] = 34242324
        assert_equal(np.where(c, a, b), r)
        # invert
        tmpmask = c != 0
        c[c == 0] = 41247212
        c[tmpmask] = 0
        assert_equal(np.where(c, b, a), r)

    @skip(reason="endianness")
    def test_foreign(self):
        """Operands and condition with explicit (non-native) byte orders."""
        c = np.array(
            [
                False,
                True,
                False,
                False,
                False,
                False,
                True,
                False,
                False,
                False,
                True,
                False,
            ]
        )
        r = np.array(
            [5.0, 1.0, 3.0, 2.0, -1.0, -4.0, 1.0, -10.0, 10.0, 1.0, 1.0, 3.0],
            dtype=np.float64,
        )
        a = np.ones(1, dtype=">i4")
        b = np.array(
            [5.0, 0.0, 3.0, 2.0, -1.0, -4.0, 0.0, -10.0, 10.0, 1.0, 0.0, 3.0],
            dtype=np.float64,
        )
        assert_equal(np.where(c, a, b), r)

        b = b.astype(">f8")
        assert_equal(np.where(c, a, b), r)

        a = a.astype("<i4")
        assert_equal(np.where(c, a, b), r)

        c = c.astype(">i4")
        assert_equal(np.where(c, a, b), r)

    def test_error(self):
        """Shapes that cannot be broadcast together must raise."""
        c = [True, True]
        a = np.ones((4, 5))
        b = np.ones((5, 5))
        assert_raises((RuntimeError, ValueError), np.where, c, a, a)
        assert_raises((RuntimeError, ValueError), np.where, c[0], a, b)

    def test_empty_result(self):
        # pass empty where result through an assignment which reads the data of
        # empty arrays, error detectable with valgrind, see gh-8922
        x = np.zeros((1, 1))
        ibad = np.vstack(np.where(x == 99.0))
        assert_array_equal(ibad, np.atleast_2d(np.array([[], []], dtype=np.intp)))

    def test_largedim(self):
        # invalid read regression gh-9304: repeated nonzero() calls on a
        # 6-D array must agree with each other.
        shape = [10, 2, 3, 4, 5, 6]
        np.random.seed(2)
        array = np.random.rand(*shape)

        for _ in range(10):
            benchmark = array.nonzero()
            result = array.nonzero()
            assert_array_equal(benchmark, result)

    def test_kwargs(self):
        """``x``/``y`` cannot be passed by keyword."""
        a = np.zeros(1)
        with assert_raises(TypeError):
            np.where(a, x=a, y=a)
6683
6684
class TestHashing(TestCase):
    """ndarrays are unhashable, both in practice and per the Hashable ABC."""

    def test_arrays_not_hashable(self):
        ones = np.ones(3)
        # hash() of an ndarray must raise TypeError.
        assert_raises(TypeError, hash, ones)

    def test_collections_hashable(self):
        empty = np.array([])
        is_hashable = isinstance(empty, collections.abc.Hashable)
        assert_(not is_hashable)
6693
6694
class TestFormat(TestCase):
    """String formatting (``format`` / f-strings) of ndarrays."""

    @xpassIfTorchDynamo  # (reason="TODO")
    def test_0d(self):
        # A numeric format spec applies to a 0-d array and to its scalar alike.
        pi_0d = np.array(np.pi)
        assert_equal(f"{pi_0d:0.3g}", "3.14")
        assert_equal(f"{pi_0d[()]:0.3g}", "3.14")

    def test_1d_no_format(self):
        # With an empty format spec the result matches str().
        pi_1d = np.array([np.pi])
        assert_equal(f"{pi_1d}", str(pi_1d))

    def test_1d_format(self):
        # until gh-5543, ensure that the behaviour matches what it used to be:
        # a non-empty format spec on a 1-d array raises TypeError.
        pi_1d = np.array([np.pi])
        assert_raises(TypeError, "{:30}".format, pi_1d)
6710
6711
6712from numpy.testing import IS_PYPY
6713
6714
class TestWritebackIfCopy(TestCase):
    """Final contents of destination arrays for ``out=``/in-place operations.

    All these tests use the WRITEBACKIFCOPY mechanism in NumPy.
    """

    def test_argmax_with_out(self):
        identity = np.eye(5)
        dest = np.empty(5, dtype="i2")
        # Column j of the identity peaks at row j.
        returned = np.argmax(identity, 0, out=dest)
        assert_equal(returned, range(5))

    def test_argmin_with_out(self):
        neg_identity = -np.eye(5)
        dest = np.empty(5, dtype="i2")
        returned = np.argmin(neg_identity, 0, out=dest)
        assert_equal(returned, range(5))

    @xpassIfTorchDynamo  # (reason="XXX: place()")
    def test_insert_noncontiguous(self):
        # The transpose forces a non-C-contiguous target; uses arr_insert.
        target = np.arange(6).reshape(2, 3).T
        np.place(target, target > 2, [44, 55])
        assert_equal(target, np.array([[0, 44], [1, 55], [2, 44]]))
        # hit one of the failing paths: the mask selects entries but no
        # values are supplied
        assert_raises(ValueError, np.place, target, target > 20, [])

    def test_put_noncontiguous(self):
        target = np.arange(6).reshape(2, 3).T  # force non-c-contiguous
        assert not target.flags["C_CONTIGUOUS"]  # sanity check
        # Flat indices address the C-order flattening of ``target``.
        np.put(target, [0, 2], [44, 55])
        assert_equal(target, np.array([[44, 3], [55, 4], [2, 5]]))

    @xpassIfTorchDynamo  # (reason="XXX: putmask()")
    def test_putmask_noncontiguous(self):
        target = np.arange(6).reshape(2, 3).T  # force non-c-contiguous
        # uses arr_putmask; replacement values are the element-wise squares
        np.putmask(target, target > 2, target**2)
        assert_equal(target, np.array([[0, 9], [1, 16], [2, 25]]))

    def test_take_mode_raise(self):
        source = np.arange(6, dtype="int")
        dest = np.empty(2, dtype="int")
        np.take(source, [0, 2], out=dest, mode="raise")
        assert_equal(dest, np.array([0, 2]))

    def test_choose_mod_raise(self):
        selector = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
        dest = np.empty((3, 3), dtype="int")
        np.choose(selector, [-10, 10], out=dest, mode="raise")
        assert_equal(dest, np.array([[10, -10, 10], [-10, 10, -10], [10, -10, 10]]))

    @xpassIfTorchDynamo  # (reason="XXX: ndarray.flat")
    def test_flatiter__array__(self):
        grid = np.arange(9).reshape(3, 3)
        flat_view = grid.T.flat
        materialized = flat_view.__array__()
        # triggers the WRITEBACKIFCOPY resolution, assuming refcount semantics
        del materialized

    def test_dot_out(self):
        # if HAVE_CBLAS, will use WRITEBACKIFCOPY; ``out`` aliases both inputs
        square = np.arange(9, dtype=float).reshape(3, 3)
        product = np.dot(square, square, out=square)
        assert_equal(product, np.array([[15, 18, 21], [42, 54, 66], [69, 90, 111]]))
6777
6778
@instantiate_parametrized_tests
class TestArange(TestCase):
    """Argument validation, keyword handling and dtype promotion of arange."""

    def test_infinite(self):
        # NumPy reports "unsupported range" here.
        assert_raises((RuntimeError, ValueError), np.arange, 0, np.inf)

    def test_nan_step(self):
        # NumPy reports "cannot compute length" here.
        assert_raises((RuntimeError, ValueError), np.arange, 0, 1, np.nan)

    def test_zero_step(self):
        # A zero step is rejected for int and float arguments alike.
        zero_step_args = [
            (0, 10, 0),
            (0.0, 10.0, 0.0),
            # empty range
            (0, 0, 0),
            (0.0, 0.0, 0.0),
        ]
        for args in zero_step_args:
            assert_raises(ZeroDivisionError, np.arange, *args)

    def test_require_range(self):
        # No positional arguments, or only ``step``/``dtype``, is an error.
        assert_raises(TypeError, np.arange)
        for kwargs in ({"step": 3}, {"dtype": "int64"}):
            assert_raises(TypeError, np.arange, **kwargs)

    @xpassIfTorchDynamo  # (reason="weird arange signature (optionals before required args)")
    def test_require_range_2(self):
        # ``start`` alone (without ``stop``) does not define a range.
        assert_raises(TypeError, np.arange, start=4)

    def test_start_stop_kwarg(self):
        only_stop = np.arange(stop=3)
        zero_to_stop = np.arange(start=0, stop=3)
        three_to_nine = np.arange(start=3, stop=9)

        for arr, length in ((only_stop, 3), (zero_to_stop, 3), (three_to_nine, 6)):
            assert len(arr) == length
        assert_array_equal(only_stop, zero_to_stop)

    @skip(reason="arange for booleans: numpy maybe deprecates?")
    def test_arange_booleans(self):
        # Arange makes some sense for booleans and works up to length 2.
        # But it is weird since `arange(2, 4, dtype=bool)` works.
        # Arguably, much or all of this could be deprecated/removed.
        got = np.arange(False, dtype=bool)
        assert_array_equal(got, np.array([], dtype="bool"))

        got = np.arange(True, dtype="bool")
        assert_array_equal(got, [False])

        got = np.arange(2, dtype="bool")
        assert_array_equal(got, [False, True])

        # This case is especially weird, but drops out without special case:
        got = np.arange(6, 8, dtype="bool")
        assert_array_equal(got, [True, True])

        with pytest.raises(TypeError):
            np.arange(3, dtype="bool")

    @parametrize("which", [0, 1, 2])
    def test_error_paths_and_promotion(self, which):
        # Replacing any one of start/stop/step with a float64 scalar must
        # yield a float64 result: first for (0, 1, 2), which can become an
        # empty range, then for the non-empty (0, 8, 2).
        for start_stop_step in ([0, 1, 2], [0, 8, 2]):
            args = list(start_stop_step)
            args[which] = np.float64(2.0)
            assert np.arange(*args).dtype == np.float64

    @parametrize("dt", [np.float32, np.uint8, complex])
    def test_explicit_dtype(self, dt):
        # An explicit dtype overrides the dtype implied by the float argument.
        result = np.arange(5.0, dtype=dt)
        assert result.dtype == dt
6857
6858
class TestRichcompareScalar(TestCase):
    """Comparisons between arrays/NumPy scalars and unrelated Python objects."""

    # Pass an explicit reason, consistent with the rest of this file, so the
    # skip report says why instead of showing an empty reason.
    @skip(reason="brittle: fails or passes under dynamo depending on the NumPy version")
    def test_richcompare_scalar_boolean_singleton_return(self):
        # These are currently guaranteed to be the boolean singletons, but maybe
        # returning NumPy booleans would also be OK:
        assert (np.array(0) == "a") is False
        assert (np.array(0) != "a") is True
        assert (np.int16(0) == "a") is False
        assert (np.int16(0) != "a") is True
6868
6869
# `skip` is `functools.partial(skipIf, True)`, so it must be *called* with a
# reason. Applied bare to a class, unittest treats the class itself as the
# reason and returns a decorator function in its place, so the TestCase would
# silently disappear instead of being reported as skipped.
@skip(reason="implement views/dtypes")
class TestViewDtype(TestCase):
    """
    Verify that making a view of a non-contiguous array works as expected.
    """

    def test_smaller_dtype_multiple(self):
        # x is non-contiguous
        x = np.arange(10, dtype="<i4")[::2]
        with pytest.raises(ValueError, match="the last axis must be contiguous"):
            x.view("<i2")
        # A trailing length-1 axis is trivially contiguous, so the view works;
        # each little-endian int32 splits into (low, high) int16 halves.
        expected = [[0, 0], [2, 0], [4, 0], [6, 0], [8, 0]]
        assert_array_equal(x[:, np.newaxis].view("<i2"), expected)

    def test_smaller_dtype_not_multiple(self):
        # x is non-contiguous
        x = np.arange(5, dtype="<i4")[::2]

        with pytest.raises(ValueError, match="the last axis must be contiguous"):
            x.view("S3")
        # 4-byte items cannot be split evenly into 3-byte ones.
        with pytest.raises(ValueError, match="When changing to a smaller dtype"):
            x[:, np.newaxis].view("S3")

        # Make sure the problem is because of the dtype size
        expected = [[b""], [b"\x02"], [b"\x04"]]
        assert_array_equal(x[:, np.newaxis].view("S4"), expected)

    def test_larger_dtype_multiple(self):
        # x is non-contiguous in the first dimension, contiguous in the last
        x = np.arange(20, dtype="<i2").reshape(10, 2)[::2, :]
        # Each little-endian (lo, hi) int16 pair becomes lo + hi * 2**16.
        expected = np.array(
            [[65536], [327684], [589832], [851980], [1114128]], dtype="<i4"
        )
        assert_array_equal(x.view("<i4"), expected)

    def test_larger_dtype_not_multiple(self):
        # x is non-contiguous in the first dimension, contiguous in the last
        x = np.arange(20, dtype="<i2").reshape(10, 2)[::2, :]
        with pytest.raises(ValueError, match="When changing to a larger dtype"):
            x.view("S3")
        # Make sure the problem is because of the dtype size
        expected = [
            [b"\x00\x00\x01"],
            [b"\x04\x00\x05"],
            [b"\x08\x00\t"],
            [b"\x0c\x00\r"],
            [b"\x10\x00\x11"],
        ]
        assert_array_equal(x.view("S4"), expected)

    def test_f_contiguous(self):
        # x is F-contiguous
        x = np.arange(4 * 3, dtype="<i4").reshape(4, 3).T
        with pytest.raises(ValueError, match="the last axis must be contiguous"):
            x.view("<i2")

    def test_non_c_contiguous(self):
        # x is contiguous in axis=-1, but not C-contiguous in other axes
        x = np.arange(2 * 3 * 4, dtype="i1").reshape(2, 3, 4).transpose(1, 0, 2)
        expected = [
            [[256, 770], [3340, 3854]],
            [[1284, 1798], [4368, 4882]],
            [[2312, 2826], [5396, 5910]],
        ]
        assert_array_equal(x.view("<i2"), expected)
6935
6936
@instantiate_parametrized_tests
class TestSortFloatMisc(TestCase):
    """Quicksort must agree with heapsort on data with NaN/inf/extreme values."""

    # Test various array sizes that hit different code paths in quicksort-avx512
    @parametrize(
        "N", [8, 16, 24, 32, 48, 64, 96, 128, 151, 191, 256, 383, 512, 1023, 2047]
    )
    def test_sort_float(self, N):
        np.random.seed(42)

        # (1) regular data with NaN sprinkled in
        arr = -0.5 + np.random.sample(N).astype("f")
        arr[np.random.choice(arr.shape[0], 3)] = np.nan
        assert_equal(np.sort(arr, kind="quick"), np.sort(arr, kind="heap"))

        # (2) with +INF
        infarr = np.inf * np.ones(N, dtype="f")
        infarr[np.random.choice(infarr.shape[0], 5)] = -1.0
        assert_equal(np.sort(infarr, kind="quick"), np.sort(infarr, kind="heap"))

        # (3) with -INF
        neginfarr = -np.inf * np.ones(N, dtype="f")
        neginfarr[np.random.choice(neginfarr.shape[0], 5)] = 1.0
        assert_equal(np.sort(neginfarr, kind="quick"), np.sort(neginfarr, kind="heap"))

        # (4) with +/-INF: up to half of the entries (indices may repeat)
        # are overwritten with -inf
        infarr = np.inf * np.ones(N, dtype="f")
        infarr[np.random.choice(infarr.shape[0], N // 2)] = -np.inf
        assert_equal(np.sort(infarr, kind="quick"), np.sort(infarr, kind="heap"))

    def test_sort_int(self):
        # Random data with NPY_MAX_INT32 and NPY_MIN_INT32 sprinkled
        np.random.seed(1234)
        N = 2047
        minv = np.iinfo(np.int32).min
        maxv = np.iinfo(np.int32).max
        arr = np.random.randint(low=minv, high=maxv, size=N).astype("int32")
        arr[np.random.choice(arr.shape[0], 10)] = minv
        arr[np.random.choice(arr.shape[0], 10)] = maxv
        assert_equal(np.sort(arr, kind="quick"), np.sort(arr, kind="heap"))
6976
6977
# Script entry point: when this file is executed directly, hand off to
# `run_tests` from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
6980