xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/core/test_scalarmath.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import contextlib
4import functools
5import itertools
6import operator
7import sys
8import warnings
9from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
10
11import numpy
12import pytest
13from pytest import raises as assert_raises
14
15from torch.testing._internal.common_utils import (
16    instantiate_parametrized_tests,
17    parametrize,
18    run_tests,
19    skipIfTorchDynamo,
20    slowTest as slow,
21    subtest,
22    TEST_WITH_TORCHDYNAMO,
23    TestCase,
24    xpassIfTorchDynamo,
25)
26
27
28if TEST_WITH_TORCHDYNAMO:
29    import numpy as np
30    from numpy.testing import (
31        _gen_alignment_data,
32        assert_,
33        assert_almost_equal,
34        assert_equal,
35    )
36else:
37    import torch._numpy as np
38    from torch._numpy.testing import (
39        _gen_alignment_data,
40        assert_,
41        assert_almost_equal,
42        assert_equal,
43    )
44
45
# Unconditional skip decorator: ``skip(reason=...)`` is ``skipif(True, reason=...)``.
skip = functools.partial(skipif, True)

# NOTE(review): not referenced in the visible part of this file; presumably
# kept for parity with the upstream NumPy test module -- confirm before removing.
IS_PYPY = False

# Scalar types iterated over by the TestTypes checks below.
types = [
    np.bool_,
    np.byte,
    np.ubyte,
    np.short,
    np.intc,
    np.int_,
    np.longlong,
    np.single,
    np.double,
    np.csingle,
    np.cdouble,
]

# Direct subclasses of the abstract floating / complex-floating scalar types;
# used to parametrize the TestAbs tests.
floating_types = np.floating.__subclasses__()
complex_floating_types = np.complexfloating.__subclasses__()

# NOTE(review): the two lists below are not referenced in the visible part of
# this file; they may be used further down -- confirm.
objecty_things = [object(), None]

reasonable_operators_for_scalars = [
    operator.lt,
    operator.le,
    operator.eq,
    operator.ne,
    operator.ge,
    operator.gt,
    operator.add,
    operator.floordiv,
    operator.mod,
    operator.mul,
    operator.pow,
    operator.sub,
    operator.truediv,
]
84
85
86# This compares scalarmath against ufuncs.
87
88
class TestTypes(TestCase):
    # Construction and addition sanity checks for every entry of ``types``.

    def test_types(self):
        # A scalar built from 1 must compare equal to the Python int 1.
        for scalar_type in types:
            one = scalar_type(1)
            assert_(one == 1, f"error with {scalar_type!r}: got {one!r}")

    def test_type_add(self):
        # For every ordered pair of types, scalar + scalar must yield the
        # same dtype as the equivalent one-element array + array.
        for lhs_idx, lhs_type in enumerate(types):
            lhs_scalar = lhs_type(3)
            lhs_array = np.array([3], dtype=lhs_type)
            for rhs_idx, rhs_type in enumerate(types):
                rhs_scalar = rhs_type(1)
                rhs_array = np.array([1], dtype=rhs_type)
                scalar_sum = lhs_scalar + rhs_scalar
                array_sum = lhs_array + rhs_array
                # It was comparing the type numbers, but the new ufunc
                # function-finding mechanism finds the lowest function
                # to which both inputs can be cast - which produces 'l'
                # when you do 'q' + 'b'.  The old function finding mechanism
                # skipped ahead based on the first argument, but that
                # does not produce properly symmetric results...
                description = "error with types (%d/'%s' + %d/'%s')" % (
                    lhs_idx,
                    np.dtype(lhs_type).name,
                    rhs_idx,
                    np.dtype(rhs_type).name,
                )
                assert_equal(scalar_sum.dtype, array_sum.dtype, description)

    def test_type_create(self):
        # Calling the scalar type on a list must match np.array with that dtype.
        for scalar_type in types:
            via_array = np.array([1, 2, 3], scalar_type)
            via_type = scalar_type([1, 2, 3])
            assert_equal(via_array, via_type)

    @skipIfTorchDynamo()  # freezes under torch.Dynamo (loop unrolling, huh)
    def test_leak(self):
        # test leak of scalar objects
        # a leak would show up in valgrind as still-reachable of ~2.6MB
        for _ in range(200000):
            np.add(1, 1)
130
131
class TestBaseMath(TestCase):
    # Element-wise arithmetic on buffers with varying alignment offsets.

    def test_blocked(self):
        # test alignments offsets for simd instructions
        # alignments for vz + 2 * (vs - 1) + 1
        configurations = [(np.float32, 11), (np.float64, 7), (np.int32, 11)]
        for dtype, max_size in configurations:
            aligned_buffers = _gen_alignment_data(
                dtype=dtype, type="binary", max_size=max_size
            )
            for out, lhs, rhs, msg in aligned_buffers:
                ones = np.ones_like(lhs)
                lhs[...] = np.ones_like(lhs)
                rhs[...] = np.zeros_like(rhs)

                # ones + zeros, array + scalar and scalar + array.
                assert_almost_equal(np.add(lhs, rhs), ones, err_msg=msg)
                assert_almost_equal(np.add(lhs, 2), ones + 2, err_msg=msg)
                assert_almost_equal(np.add(1, rhs), ones, err_msg=msg)

                # Same addition, written into the provided output buffer.
                np.add(lhs, rhs, out=out)
                assert_almost_equal(out, ones, err_msg=msg)

                # Turn rhs into 1, 2, 3, ... and compare square with x * x.
                rhs[...] += np.arange(rhs.size, dtype=dtype) + 1
                assert_almost_equal(
                    np.square(rhs), np.multiply(rhs, rhs), err_msg=msg
                )
                # skip true divide for ints
                if dtype != np.int32:
                    assert_almost_equal(
                        np.reciprocal(rhs), np.divide(1, rhs), err_msg=msg
                    )

                # Scalar operand on either side, with an explicit output.
                lhs[...] = np.ones_like(lhs)
                np.add(lhs, 2, out=out)
                assert_almost_equal(out, ones + 2, err_msg=msg)
                rhs[...] = np.ones_like(rhs)
                np.add(2, rhs, out=out)
                assert_almost_equal(out, ones + 2, err_msg=msg)

    @xpassIfTorchDynamo  # (reason="pytorch does not have .view")
    def test_lower_align(self):
        # check data that is not aligned to element size
        # i.e doubles are aligned to 4 bytes on i386
        operand = np.zeros(23 * 8, dtype=np.int8)[4:-4].view(np.float64)
        result = np.zeros(23 * 8, dtype=np.int8)[4:-4].view(np.float64)
        assert_almost_equal(operand + operand, operand * 2)
        # The remaining calls only need to complete without raising.
        np.add(operand, operand, out=result)
        np.add(np.ones_like(operand), operand, out=result)
        np.add(operand, np.ones_like(operand), out=result)
        np.add(np.ones_like(operand), operand)
        np.add(operand, np.ones_like(operand))
179
180
class TestPower(TestCase):
    # Behaviour of ``**`` / ``pow`` on scalars of various dtypes.

    def test_small_types(self):
        # 3 ** 4 fits into every one of these narrow types.
        for scalar_type in [np.int8, np.int16, np.float16]:
            base = scalar_type(3)
            result = base**4
            assert_(result == 81, f"error with {scalar_type!r}: got {result!r}")

    def test_large_types(self):
        # 51 ** 4: exact for integer types, approximate for floats.
        for scalar_type in [np.int32, np.int64, np.float32, np.float64]:
            base = scalar_type(51)
            result = base**4
            msg = f"error with {scalar_type!r}: got {result!r}"
            if not np.issubdtype(scalar_type, np.integer):
                assert_almost_equal(result, 6765201, err_msg=msg)
            else:
                assert_(result == 6765201, msg)

    @skip(reason="NP_VER: fails on CI on older NumPy")
    @xpassIfTorchDynamo  # (reason="Value-based casting: (2)**(-2) -> 0 in pytorch.")
    def test_integers_to_negative_integer_power(self):
        # Note that the combination of uint64 with a signed integer
        # has common type np.float64. The other combinations should all
        # raise a ValueError for integer ** negative integer.
        exponents = [np.array(-1, dt)[()] for dt in "bhil"]

        # (base value, base type codes, result expected for a uint64 base)
        scenarios = (
            (1, "bhilB", 1.0),  # 1 ** -1 possible special case
            (-1, "bhil", -1.0),  # -1 ** -1 possible special case
            (2, "bhilB", 0.5),  # 2 ** -1 perhaps generic
        )
        for value, codes, uint64_result in scenarios:
            bases = [np.array(value, dt)[()] for dt in codes]
            for i1, i2 in itertools.product(bases, exponents):
                if i1.dtype == np.uint64:
                    res = operator.pow(i1, i2)
                    assert_(res.dtype.type is np.float64)
                    assert_almost_equal(res, uint64_result)
                else:
                    assert_raises(ValueError, operator.pow, i1, i2)

    def test_mixed_types(self):
        # 3 ** 2 for every ordered pair of base / exponent types.
        typelist = [
            np.int8,
            np.int16,
            np.float16,
            np.float32,
            np.float64,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
        ]
        for t1, t2 in itertools.product(typelist, repeat=2):
            result = t1(3) ** t2(2)
            msg = f"error with {t1!r} and {t2!r}:" f"got {result!r}, expected {9!r}"
            if not np.issubdtype(np.dtype(result), np.integer):
                assert_almost_equal(result, 9, err_msg=msg)
            else:
                assert_(result == 9, msg)

    def test_modular_power(self):
        # modular power is not implemented, so ensure it errors
        base, exponent, modulus = 5, 4, 10
        expected = pow(base, exponent, modulus)  # noqa: F841
        for scalar_type in (np.int32, np.float32, np.complex64):
            # note that 3-operand power only dispatches on the first argument
            wrapped = scalar_type(base)
            assert_raises(TypeError, operator.pow, wrapped, exponent, modulus)
            assert_raises(
                TypeError, operator.pow, np.array(scalar_type(base)), exponent, modulus
            )
269
270
def floordiv_and_mod(x, y):
    """Return ``(x // y, x % y)``, computed with the two separate operators."""
    quotient = x // y
    remainder = x % y
    return (quotient, remainder)
273
274
def _signs(dt):
    """Return the signs to test for type code ``dt``.

    Codes listed in ``np.typecodes["UnsignedInteger"]`` get ``(+1,)`` only;
    every other code gets ``(+1, -1)``.
    """
    return (+1,) if dt in np.typecodes["UnsignedInteger"] else (+1, -1)
280
281
@instantiate_parametrized_tests
class TestModulus(TestCase):
    # Floor division and modulus on scalars, exercised both through the
    # separate ``//`` and ``%`` operators (floordiv_and_mod) and ``divmod``.

    def test_modulus_basic(self):
        # The identity div * b + rem == a must hold and the remainder must
        # carry the sign of the divisor, for every dtype / sign combination.
        # dt = np.typecodes["AllInteger"] + np.typecodes["Float"]
        dt = "Bbhil" + "efd"
        for op in [floordiv_and_mod, divmod]:
            for dt1, dt2 in itertools.product(dt, dt):
                for sg1, sg2 in itertools.product(_signs(dt1), _signs(dt2)):
                    fmt = "op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s"
                    msg = fmt % (op.__name__, dt1, dt2, sg1, sg2)
                    a = np.array(sg1 * 71, dtype=dt1)[()]
                    b = np.array(sg2 * 19, dtype=dt2)[()]
                    div, rem = op(a, b)
                    assert_equal(div * b + rem, a, err_msg=msg)
                    if sg2 == -1:
                        assert_(b < rem <= 0, msg)
                    else:
                        assert_(b > rem >= 0, msg)

    @slow
    def test_float_modulus_exact(self):
        # test that float results are exact for small integers. This also
        # holds for the same integers scaled by powers of two.
        nlst = list(range(-127, 0))
        plst = list(range(1, 128))
        dividend = nlst + [0] + plst
        divisor = nlst + plst
        arg = list(itertools.product(dividend, divisor))
        tgt = [divmod(*t) for t in arg]

        a, b = np.array(arg, dtype=int).T
        # convert exact integer results from Python to float so that
        # signed zero can be used, it is checked.
        tgtdiv, tgtrem = np.array(tgt, dtype=float).T
        tgtdiv = np.where((tgtdiv == 0.0) & ((b < 0) ^ (a < 0)), -0.0, tgtdiv)
        tgtrem = np.where((tgtrem == 0.0) & (b < 0), -0.0, tgtrem)

        for op in [floordiv_and_mod, divmod]:
            for dt in np.typecodes["Float"]:
                msg = f"op: {op.__name__}, dtype: {dt}"
                fa = a.astype(dt)
                fb = b.astype(dt)
                # use list comprehension so a_ and b_ are scalars
                div, rem = zip(*[op(a_, b_) for a_, b_ in zip(fa, fb)])
                assert_equal(div, tgtdiv, err_msg=msg)
                assert_equal(rem, tgtrem, err_msg=msg)

    def test_float_modulus_roundoff(self):
        # gh-6127
        # dt = np.typecodes["Float"]
        dt = "efd"
        for op in [floordiv_and_mod, divmod]:
            for dt1, dt2 in itertools.product(dt, dt):
                for sg1, sg2 in itertools.product((+1, -1), (+1, -1)):
                    fmt = "op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s"
                    msg = fmt % (op.__name__, dt1, dt2, sg1, sg2)
                    a = np.array(sg1 * 78 * 6e-8, dtype=dt1)[()]
                    b = np.array(sg2 * 6e-8, dtype=dt2)[()]
                    div, rem = op(a, b)
                    # Equal assertion should hold when fmod is used
                    assert_equal(div * b + rem, a, err_msg=msg)
                    if sg2 == -1:
                        assert_(b < rem <= 0, msg)
                    else:
                        assert_(b > rem >= 0, msg)

    @parametrize("dt", "efd")
    def test_float_modulus_corner_cases(self, dt):
        if dt == "e":
            # FIXME: make xfail
            raise SkipTest("RuntimeError: 'nextafter_cpu' not implemented for 'Half'")

        # The remainder of the value just below zero must stay within the
        # divisor's magnitude.
        b = np.array(1.0, dtype=dt)
        a = np.nextafter(np.array(0.0, dtype=dt), -b)
        rem = operator.mod(a, b)
        assert_(rem <= b, f"dt: {dt}")
        rem = operator.mod(-a, -b)
        assert_(rem >= -b, f"dt: {dt}")

        # Check nans, inf
        # (upstream wrapped this section in suppress_warnings() filters for
        # the "invalid value" / "divide by zero" RuntimeWarnings)
        # This section walks every float code itself; a separate loop name is
        # used so the parametrized ``dt`` is not shadowed.
        for code in "efd":
            fone = np.array(1.0, dtype=code)
            fzer = np.array(0.0, dtype=code)
            finf = np.array(np.inf, dtype=code)
            fnan = np.array(np.nan, dtype=code)
            rem = operator.mod(fone, fzer)
            assert_(np.isnan(rem), f"dt: {code}")
            # MSVC 2008 returns NaN here, so disable the check.
            # rem = operator.mod(fone, finf)
            # assert_(rem == fone, 'dt: %s' % dt)
            rem = operator.mod(fone, fnan)
            assert_(np.isnan(rem), f"dt: {code}")
            rem = operator.mod(finf, fone)
            assert_(np.isnan(rem), f"dt: {code}")
            for op in [floordiv_and_mod, divmod]:
                div, mod = op(fone, fzer)
                # Two separate statements: assert_() returns None, so chaining
                # them with ``and`` would silently skip the second check.
                assert_(np.isinf(div), f"dt: {code}")
                assert_(np.isnan(mod), f"dt: {code}")
385
386
class TestComplexDivision(TestCase):
    # Division of complex scalars: zero denominators, signed zeros and the
    # different branches taken depending on the denominator's components.

    @skip(reason="With pytorch, 1/(0+0j) is nan + nan*j, not inf + nan*j")
    def test_zero_division(self):
        infinite_numerators = (
            1.0,
            complex(np.inf, np.inf),
            complex(np.inf, np.nan),
            complex(np.nan, np.inf),
        )
        nan_numerators = (complex(np.nan, np.nan), 0.0)
        for t in [np.complex64, np.complex128]:
            zero = t(0.0)
            for numerator in infinite_numerators:
                assert_(np.isinf(t(numerator) / zero))
            for numerator in nan_numerators:
                assert_(np.isnan(t(numerator) / zero))

    def test_signed_zeros(self):
        # tupled (numerator, denominator, expected)
        # for testing as expected == numerator/denominator
        data = (
            ((0.0, -1.0), (0.0, 1.0), (-1.0, -0.0)),
            ((0.0, -1.0), (0.0, -1.0), (1.0, -0.0)),
            ((0.0, -1.0), (-0.0, -1.0), (1.0, 0.0)),
            ((0.0, -1.0), (-0.0, 1.0), (-1.0, 0.0)),
            ((0.0, 1.0), (0.0, -1.0), (-1.0, 0.0)),
            ((0.0, -1.0), (0.0, -1.0), (1.0, -0.0)),
            ((-0.0, -1.0), (0.0, -1.0), (1.0, -0.0)),
            ((-0.0, 1.0), (0.0, -1.0), (-1.0, -0.0)),
        )
        for t in [np.complex64, np.complex128]:
            for (n_re, n_im), (d_re, d_im), (ex_re, ex_im) in data:
                quotient = t(complex(n_re, n_im)) / t(complex(d_re, d_im))
                # check real and imag parts separately to avoid comparison
                # in array context, which does not account for signed zeros
                assert_equal(quotient.real, ex_re)
                assert_equal(quotient.imag, ex_im)

    def test_branches(self):
        # tupled (numerator, denominator, expected)
        # for testing as expected == numerator/denominator
        data = [
            # trigger branch: real(fabs(denom)) > imag(fabs(denom))
            # followed by else condition as neither are == 0
            ((2.0, 1.0), (2.0, 1.0), (1.0, 0.0)),
            # trigger branch: real(fabs(denom)) > imag(fabs(denom))
            # followed by if condition as both are == 0
            # is performed in test_zero_division(), so this is skipped
            #
            # trigger else if branch: real(fabs(denom)) < imag(fabs(denom))
            ((1.0, 2.0), (1.0, 2.0), (1.0, 0.0)),
        ]
        for t in [np.complex64, np.complex128]:
            for (n_re, n_im), (d_re, d_im), (ex_re, ex_im) in data:
                quotient = t(complex(n_re, n_im)) / t(complex(d_re, d_im))
                # check real and imag parts separately to avoid comparison
                # in array context, which does not account for signed zeros
                assert_equal(quotient.real, ex_re)
                assert_equal(quotient.imag, ex_im)
455
456
class TestConversion(TestCase):
    # Conversions between Python numbers and array scalars, plus the
    # relational operators between scalars of different dtypes.

    def test_int_from_long(self):
        # NB: this test assumes that the default fp type is float64
        as_floats = [1e6, 1e12, 1e18, -1e6, -1e12, -1e18]
        as_ints = [10**6, 10**12, 10**18, -(10**6), -(10**12), -(10**18)]
        for T in [None, np.float64, np.int64]:
            a = np.array(as_floats, dtype=T)
            assert_equal([int(_m) for _m in a], as_ints)

    # Compare parsed versions rather than raw strings: lexicographically,
    # e.g. "1.9.0" sorts *after* "1.24".
    @skipif(
        numpy.lib.NumpyVersion(numpy.__version__) < "1.24.0",
        reason="NP_VER: fails on NumPy 1.23.x",
    )
    @xpassIfTorchDynamo  # (reason="pytorch does not emit this warning.")
    def test_iinfo_long_values_1(self):
        # Storing max + 1 is expected to warn and wrap around to the minimum.
        for code in "bBh":
            with pytest.warns(DeprecationWarning):
                res = np.array(np.iinfo(code).max + 1, dtype=code)
            tgt = np.iinfo(code).min
            assert_(res == tgt)

    def test_iinfo_long_values_2(self):
        # The maximum value must round-trip through np.array ...
        for code in np.typecodes["AllInteger"]:
            res = np.array(np.iinfo(code).max, dtype=code)
            tgt = np.iinfo(code).max
            assert_(res == tgt)

        # ... and through the scalar type constructor.
        for code in np.typecodes["AllInteger"]:
            res = np.dtype(code).type(np.iinfo(code).max)
            tgt = np.iinfo(code).max
            assert_(res == tgt)

    def test_int_raise_behaviour(self):
        # Constructing a scalar from max + 1 must raise.
        def overflow_error_func(dtype):
            dtype(np.iinfo(dtype).max + 1)

        for code in [np.int_, np.longlong]:
            assert_raises((OverflowError, RuntimeError), overflow_error_func, code)

    def test_numpy_scalar_relational_operators(self):
        # All integer
        for dt1 in np.typecodes["AllInteger"]:
            assert_(1 > np.array(0, dtype=dt1)[()], f"type {dt1} failed")
            assert_(not 1 < np.array(0, dtype=dt1)[()], f"type {dt1} failed")

            for dt2 in np.typecodes["AllInteger"]:
                assert_(
                    np.array(1, dtype=dt1)[()] > np.array(0, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1)[()] < np.array(0, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )

        # Signed integers and floats
        for dt1 in "bhl" + np.typecodes["Float"]:
            assert_(1 > np.array(-1, dtype=dt1)[()], f"type {dt1} failed")
            assert_(not 1 < np.array(-1, dtype=dt1)[()], f"type {dt1} failed")
            assert_(-1 == np.array(-1, dtype=dt1)[()], f"type {dt1} failed")

            for dt2 in "bhl" + np.typecodes["Float"]:
                assert_(
                    np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )

    def test_numpy_scalar_relational_operators_2(self):
        # Unsigned integers
        for dt1 in "B":
            assert_(-1 < np.array(1, dtype=dt1)[()], f"type {dt1} failed")
            assert_(not -1 > np.array(1, dtype=dt1)[()], f"type {dt1} failed")
            assert_(-1 != np.array(1, dtype=dt1)[()], f"type {dt1} failed")

            # unsigned vs signed
            for dt2 in "bhil":
                assert_(
                    np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )
                assert_(
                    np.array(1, dtype=dt1)[()] != np.array(-1, dtype=dt2)[()],
                    f"type {dt1} and {dt2} failed",
                )

    def test_scalar_comparison_to_none(self):
        # Scalars should just return False and not give a warnings.
        # The comparisons are flagged by pep8, ignore that.
        # NOTE(review): both checks below are identity tests, not ==/!=
        # comparisons, so no comparison with None is actually exercised --
        # confirm whether that is intentional for this port.
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always", "", FutureWarning)
            assert_(np.float32(1) is not None)
            assert_(np.float32(1) is not None)
        assert_(len(w) == 0)
559
560
561# class TestRepr:
562#    def test_repr(self):
563#        for t in types:
564#            val = t(1197346475.0137341)
565#            val_repr = repr(val)
566#            val2 = eval(val_repr)
567#            assert_equal( val, val2 )
568
569
@xpassIfTorchDynamo  # (reason="can delegate repr to pytorch")
class TestRepr(TestCase):
    # repr() of a float scalar must evaluate back to the same value.

    def _test_type_repr(self, t):
        # Build two values of type ``t`` by setting a single bit in an
        # all-zero byte buffer, then round-trip each through repr()/eval().
        finfo = np.finfo(t)
        # Index of the bit to set for each probe value.
        # could add some more types to the list below
        probe_bits = {
            "small denorm": finfo.nexp + finfo.nmant,
            "small norm": finfo.nexp,
        }
        storage_bytes = np.dtype(t).itemsize * 8
        for bit_idx in probe_bits.values():
            # Values from https://en.wikipedia.org/wiki/IEEE_754
            constr = np.array([0x00] * storage_bytes, dtype=np.uint8)
            byte, offset = divmod(bit_idx, 8)
            constr[byte] = 1 << (7 - offset)
            val = constr.view(t)[0]
            val2 = t(eval(repr(val)))
            # Tolerate tiny values that come back as exactly zero.
            if val2 == 0 and val < 1e-100:
                continue
            assert_equal(val, val2)

    def test_float_repr(self):
        # long double test cannot work, because eval goes through a python
        # float
        for float_type in [np.float32, np.float64]:
            self._test_type_repr(float_type)
602
603
@skip(reason="Array scalars do not decay to python scalars.")
class TestMultiply(TestCase):
    # Multiplication of sequences / array-likes by array scalars.
    # The whole class is currently skipped (see the decorator's reason).

    def test_seq_repeat(self):
        # Test that basic sequences get repeated when multiplied with
        # numpy integers. And errors are raised when multiplied with others.
        # Some of this behaviour may be controversial and could be open for
        # change.
        accepted_types = set(np.typecodes["AllInteger"])
        deprecated_types = {"?"}
        forbidden_types = set(np.typecodes["All"]) - accepted_types - deprecated_types
        forbidden_types -= {"V"}  # can't default-construct void scalars

        for seq_type in (list, tuple):
            seq = seq_type([1, 2, 3])
            for numpy_type in accepted_types:
                i = np.dtype(numpy_type).type(2)
                assert_equal(seq * i, seq * int(i))
                assert_equal(i * seq, int(i) * seq)

            for numpy_type in deprecated_types:
                i = np.dtype(numpy_type).type()
                # pytest.warns replaces assert_warns, which this module
                # never imports (it would raise NameError).
                with pytest.warns(DeprecationWarning):
                    repeated = operator.mul(seq, i)
                assert_equal(repeated, seq * int(i))
                with pytest.warns(DeprecationWarning):
                    repeated = operator.mul(i, seq)
                assert_equal(repeated, int(i) * seq)

            for numpy_type in forbidden_types:
                i = np.dtype(numpy_type).type()
                assert_raises(TypeError, operator.mul, seq, i)
                assert_raises(TypeError, operator.mul, i, seq)

    def test_no_seq_repeat_basic_array_like(self):
        # Test that an array-like which does not know how to be multiplied
        # does not attempt sequence repeat (raise TypeError).
        # See also gh-7428.
        class ArrayLike:
            def __init__(self, arr):
                self.arr = arr

            def __array__(self):
                return self.arr

        # Test for simple ArrayLike above and memoryviews (original report)
        # assert_equal replaces assert_array_equal, which this module never
        # imports (it would raise NameError).
        for arr_like in (ArrayLike(np.ones(3)), memoryview(np.ones(3))):
            assert_equal(arr_like * np.float32(3.0), np.full(3, 3.0))
            assert_equal(np.float32(3.0) * arr_like, np.full(3, 3.0))
            assert_equal(arr_like * np.int_(3), np.full(3, 3))
            assert_equal(np.int_(3) * arr_like, np.full(3, 3))
654
655
class TestNegative(TestCase):
    # Unary minus on array scalars.

    def test_exceptions(self):
        # Negating a boolean scalar must raise.
        flag = np.ones((), dtype=np.bool_)[()]
        # XXX: TypeError from numpy, RuntimeError from torch
        assert_raises((TypeError, RuntimeError), operator.neg, flag)

    def test_result(self):
        # Unsigned types wrap around to their maximum; for every other type
        # -x + x must be zero.
        # (upstream ran this under suppress_warnings() filtering RuntimeWarning)
        codes = np.typecodes["AllInteger"] + np.typecodes["AllFloat"]
        for code in codes:
            one = np.ones((), dtype=code)[()]
            if code not in np.typecodes["UnsignedInteger"]:
                assert_equal(operator.neg(one) + one, 0)
                continue
            scalar_type = np.dtype(code).type
            largest = scalar_type(np.iinfo(code).max)
            assert_equal(operator.neg(one), largest)
674
675
class TestSubtract(TestCase):
    # Binary minus on array scalars.

    def test_exceptions(self):
        # Subtracting boolean scalars must raise.
        flag = np.ones((), dtype=np.bool_)[()]
        # XXX: TypeError from numpy, RuntimeError from torch
        assert_raises((TypeError, RuntimeError), operator.sub, flag, flag)

    def test_result(self):
        # x - x must be zero for every integer and float type code.
        # (upstream ran this under suppress_warnings() filtering RuntimeWarning)
        codes = np.typecodes["AllInteger"] + np.typecodes["AllFloat"]
        for code in codes:
            one = np.ones((), dtype=code)[()]
            assert_equal(operator.sub(one, one), 0)
689
690
@instantiate_parametrized_tests
class TestAbs(TestCase):
    # Absolute value of float / complex scalars via abs() and np.abs.

    def _test_abs_func(self, absfunc, test_dtype):
        # Shared checks, run once per (abs implementation, dtype) pair.
        assert_equal(absfunc(test_dtype(-1.5)), 1.5)

        # assert_equal() checks zero signedness
        for zero in (0.0, -0.0):
            assert_equal(absfunc(test_dtype(zero)), 0.0)

        limits = np.finfo(test_dtype)

        largest = test_dtype(limits.max)
        assert_equal(absfunc(largest), largest.real)

        # (upstream wrapped the ``tiny`` lookup in suppress_warnings()
        # filtering UserWarning)
        smallest = test_dtype(limits.tiny)
        assert_equal(absfunc(smallest), smallest.real)

        lowest = test_dtype(limits.min)
        assert_equal(absfunc(lowest), -lowest.real)

    @parametrize("dtype", floating_types + complex_floating_types)
    def test_builtin_abs(self, dtype):
        self._test_abs_func(abs, dtype)

    @parametrize("dtype", floating_types + complex_floating_types)
    def test_numpy_abs(self, dtype):
        self._test_abs_func(np.abs, dtype)
722
723
@instantiate_parametrized_tests
class TestBitShifts(TestCase):
    @parametrize("type_code", np.typecodes["AllInteger"])
    @parametrize("op", [operator.rshift, operator.lshift])
    def test_shift_all_bits(self, type_code, op):
        """Shifts where the shift amount is the width of the type or wider"""
        # gh-2449
        dt = np.dtype(type_code)
        nbits = dt.itemsize * 8
        not_yet_supported = (
            np.dtype(np.uint64),
            np.dtype(np.uint32),
            np.dtype(np.uint16),
        )
        if dt in not_yet_supported:
            raise SkipTest("NYI: bitshift uint64")

        for val, shift in itertools.product([5, -5], [nbits, nbits + 4]):
            val_scl = np.array(val).astype(dt)[()]
            shift_scl = dt.type(shift)

            res_scl = op(val_scl, shift_scl)
            if val_scl < 0 and op is operator.rshift:
                # sign bit is preserved
                assert_equal(res_scl, -1)
            elif type_code in ("i", "l") and shift == np.iinfo(type_code).bits:
                # FIXME: make xfail
                raise SkipTest("https://github.com/pytorch/pytorch/issues/70904")
            else:
                assert_equal(res_scl, 0)

            # Result on scalars should be the same as on arrays
            val_arr = np.array([val_scl] * 32, dtype=dt)
            shift_arr = np.array([shift] * 32, dtype=dt)
            assert_equal(op(val_arr, shift_arr), res_scl)
758
759
@skip(reason="Will rely on pytest for hashing")
@instantiate_parametrized_tests
class TestHash(TestCase):
    # Array scalars must hash like the equivalent Python numbers.
    # The whole class is currently skipped (see the decorator's reason).

    @parametrize("type_code", np.typecodes["AllInteger"])
    def test_integer_hashes(self, type_code):
        scalar_type = np.dtype(type_code).type
        for number in range(128):
            assert hash(number) == hash(scalar_type(number))

    @parametrize("type_code", np.typecodes["AllFloat"])
    def test_float_and_complex_hashes(self, type_code):
        scalar_type = np.dtype(type_code).type
        for candidate in [np.pi, np.inf, 3, 6.0]:
            numpy_val = scalar_type(candidate)
            # Cast back to Python, in case the NumPy scalar has less precision
            to_python = complex if numpy_val.dtype.kind == "c" else float
            python_val = to_python(numpy_val)
            assert python_val == numpy_val
            assert hash(python_val) == hash(numpy_val)

        if hash(float(np.nan)) != hash(float(np.nan)):
            # If Python distinguishes different NaNs we do so too (gh-18833)
            assert hash(scalar_type(np.nan)) != hash(scalar_type(np.nan))

    @parametrize("type_code", np.typecodes["Complex"])
    def test_complex_hashes(self, type_code):
        # Test some complex valued hashes specifically:
        scalar_type = np.dtype(type_code).type
        for candidate in [np.pi + 1j, np.inf - 3j, 3j, 6.0 + 1j]:
            numpy_val = scalar_type(candidate)
            assert hash(complex(numpy_val)) == hash(numpy_val)
793
794
@contextlib.contextmanager
def recursionlimit(n):
    """Temporarily set the interpreter recursion limit to ``n``.

    The limit that was active on entry is restored when the ``with`` block
    exits, whether it finishes normally or raises.
    """
    previous_limit = sys.getrecursionlimit()
    with contextlib.ExitStack() as restore:
        # Register the restore step before changing anything so it always runs.
        restore.callback(sys.setrecursionlimit, previous_limit)
        sys.setrecursionlimit(n)
        yield
803
804
@instantiate_parametrized_tests
class TestScalarOpsMisc(TestCase):
    """Warning behaviour of integer scalar arithmetic.

    Each test performs an operation that is expected to emit a
    ``RuntimeWarning`` (overflow or division by zero).  Most are marked as
    expected failures or skipped; the reason is given next to each decorator.
    """

    @xfail  # (reason="pytorch does not warn on overflow")
    @parametrize("dtype", "Bbhil")
    @parametrize(
        "operation",
        [
            lambda lo, hi: hi + hi,
            lambda lo, hi: lo - hi,
            lambda lo, hi: hi * hi,
        ],
    )
    def test_scalar_integer_operation_overflow(self, dtype, operation):
        """Arithmetic on the dtype's extreme values should warn about overflow."""
        st = np.dtype(dtype).type
        # `lo`/`hi` rather than `min`/`max` to avoid shadowing the builtins.
        lo = st(np.iinfo(dtype).min)
        hi = st(np.iinfo(dtype).max)

        with pytest.warns(RuntimeWarning, match="overflow encountered"):
            operation(lo, hi)

    @skip(reason="integer overflow UB: crashes pytorch under ASAN")
    @parametrize("dtype", "bhil")
    @parametrize(
        "operation",
        [
            lambda lo, neg_1: -lo,
            lambda lo, neg_1: abs(lo),
            lambda lo, neg_1: lo * neg_1,
            subtest(
                lambda lo, neg_1: lo // neg_1,
                decorators=[skip(reason="broken on some platforms")],
            ),
        ],
    )
    def test_scalar_signed_integer_overflow(self, dtype, operation):
        """Negating the minimum signed value (in several ways) should warn."""
        # The minimum signed integer can "overflow" for some additional operations
        st = np.dtype(dtype).type
        lo = st(np.iinfo(dtype).min)
        neg_1 = st(-1)

        with pytest.warns(RuntimeWarning, match="overflow encountered"):
            operation(lo, neg_1)

    # Compare parsed versions, not raw strings: a lexicographic comparison
    # mis-orders e.g. "1.9.0" or "10.0.0" against "1.24".
    @skipif(
        numpy.lib.NumpyVersion(numpy.__version__) < "1.24.0",
        reason="NP_VER: fails on NumPy 1.23.x",
    )
    @xpassIfTorchDynamo  # (reason="pytorch does not warn on overflow")
    @parametrize("dtype", "B")
    def test_scalar_unsigned_integer_overflow(self, dtype):
        """Negating a non-zero unsigned scalar should warn; negating zero should not."""
        val = np.dtype(dtype).type(8)
        with pytest.warns(RuntimeWarning, match="overflow encountered"):
            # Bare expression on purpose: only the emitted warning matters.
            -val

        zero = np.dtype(dtype).type(0)
        -zero  # does not warn

    @xfail  # (reason="pytorch raises RuntimeError on division by zero")
    @parametrize("dtype", np.typecodes["AllInteger"])
    @parametrize(
        "operation",
        [
            lambda val, zero: val // zero,
            lambda val, zero: val % zero,
        ],
    )
    def test_scalar_integer_operation_divbyzero(self, dtype, operation):
        """Integer floor-division and modulo by zero should warn, not raise."""
        st = np.dtype(dtype).type
        val = st(100)
        zero = st(0)

        with pytest.warns(RuntimeWarning, match="divide by zero"):
            operation(val, zero)
875
876
# Operator table consumed by the parametrized subclassing tests.  Each entry is
# (forward dunder name, reflected dunder name, operator function, cmp flag);
# the flag is True for the comparison operators and False for the arithmetic
# ones.
ops_with_names = [
    ("__lt__", "__gt__", operator.lt, True),
    ("__le__", "__ge__", operator.le, True),
    ("__eq__", "__eq__", operator.eq, True),
    # Note __op__ and __rop__ may be identical here:
    ("__ne__", "__ne__", operator.ne, True),
    ("__gt__", "__lt__", operator.gt, True),
    ("__ge__", "__le__", operator.ge, True),
    ("__floordiv__", "__rfloordiv__", operator.floordiv, False),
    ("__truediv__", "__rtruediv__", operator.truediv, False),
    ("__add__", "__radd__", operator.add, False),
    ("__mod__", "__rmod__", operator.mod, False),
    ("__mul__", "__rmul__", operator.mul, False),
    ("__pow__", "__rpow__", operator.pow, False),
    ("__sub__", "__rsub__", operator.sub, False),
]
893
894
@instantiate_parametrized_tests
class TestScalarSubclassingMisc(TestCase):
    """Operator dispatch involving scalar subclasses (both tests are skipped)."""

    @skip(reason="We do not support subclassing scalars.")
    @parametrize("__op__, __rop__, op, cmp", ops_with_names)
    @parametrize("sctype", [np.float32, np.float64])
    def test_subclass_deferral(self, sctype, __op__, __rop__, op, cmp):
        """
        Exercise deferral between subclasses of a scalar type.

        This is exceedingly complicated, especially because it tends to fall
        back to the array paths, which additionally bring in the "array
        priority" mechanism.

        The behaviour was modified subtly in 1.22 (to bring it closer to how
        Python scalars work).  Given that complexity, and that subclassing
        NumPy scalars is probably a bad idea to begin with, there is probably
        room for adjustments here.
        """

        class myf_simple1(sctype):
            pass

        class myf_simple2(sctype):
            pass

        def forward_stub(self, other):
            return __op__

        def reflected_stub(self, other):
            return __rop__

        overrides = {__op__: forward_stub, __rop__: reflected_stub}
        myf_op = type("myf_op", (sctype,), overrides)

        # inheritance has to override, or this is correctly lost:
        plain_result = op(myf_simple1(1), myf_simple2(2))
        assert type(plain_result) in (sctype, np.bool_)
        assert op(myf_simple1(1), myf_simple2(2)) == op(1, 2)  # inherited

        # Two independent subclasses do not really define an order.  This could
        # be attempted, but we do not since Python's `int` does neither:
        assert op(myf_op(1), myf_simple1(2)) == __op__
        assert op(myf_simple1(1), myf_op(2)) == op(1, 2)  # inherited

    @skip(reason="We do not support subclassing scalars.")
    @parametrize("__op__, __rop__, op, cmp", ops_with_names)
    @parametrize("subtype", [float, int, complex, np.float16])
    # @np._no_nep50_warning()
    def test_pyscalar_subclasses(self, subtype, __op__, __rop__, op, cmp):
        """Check deferral to, and normal handling of, subclasses of ``subtype``."""

        def forward_stub(self, other):
            return __op__

        def reflected_stub(self, other):
            return __rop__

        # Check that deferring is indicated using `__array_ufunc__`:
        deferring_namespace = {
            __op__: forward_stub,
            __rop__: reflected_stub,
            "__array_ufunc__": None,
        }
        myt = type("myt", (subtype,), deferring_namespace)

        # Just like normally, we should never presume we can modify the float.
        assert op(myt(1), np.float64(2)) == __op__
        assert op(np.float64(1), myt(2)) == __rop__

        # Stop here when modulo/floor-division meets complex (not supported for
        # complex, so not tested) or when the forward and reflected names
        # coincide.
        if (
            op in {operator.mod, operator.floordiv} and subtype == complex
        ) or __rop__ == __op__:
            return

        # When no deferring is indicated, subclasses are handled normally.
        myt = type("myt", (subtype,), {__rop__: reflected_stub})

        # Use float16/float32 operands: per the original note, a float
        # subclass may behave differently against float64.
        got = op(myt(1), np.float16(2))
        want = op(subtype(1), np.float16(2))
        assert got == want
        assert type(got) == type(want)
        got = op(np.float32(2), myt(1))
        want = op(np.float32(2), subtype(1))
        assert got == want
        assert type(got) == type(want)
976
977
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
980