xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/core/test_einsum.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import functools
4import itertools
5from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
6
7from pytest import raises as assert_raises
8
9import torch._numpy as np
10from torch._numpy.testing import (
11    assert_,
12    assert_allclose,
13    assert_almost_equal,
14    assert_array_equal,
15    assert_equal,
16    suppress_warnings,
17)
18from torch.testing._internal.common_utils import (
19    instantiate_parametrized_tests,
20    parametrize,
21    run_tests,
22    TestCase,
23)
24
25
26skip = functools.partial(skipif, True)
27
28
29# Setup for optimize einsum
30chars = "abcdefghij"
31sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3])
32global_size_dict = dict(zip(chars, sizes))
33
34
35@instantiate_parametrized_tests
36class TestEinsum(TestCase):
37    def test_einsum_errors(self):
38        for do_opt in [True, False]:
39            # Need enough arguments
40            assert_raises(
41                (TypeError, IndexError, ValueError), np.einsum, optimize=do_opt
42            )
43            assert_raises((IndexError, ValueError), np.einsum, "", optimize=do_opt)
44
45            # subscripts must be a string
46            assert_raises((AttributeError, TypeError), np.einsum, 0, 0, optimize=do_opt)
47
48            # out parameter must be an array
49            assert_raises(TypeError, np.einsum, "", 0, out="test", optimize=do_opt)
50
51            # order parameter must be a valid order
52            assert_raises(
53                (NotImplementedError, ValueError),
54                np.einsum,
55                "",
56                0,
57                order="W",
58                optimize=do_opt,
59            )
60
61            # casting parameter must be a valid casting
62            assert_raises(ValueError, np.einsum, "", 0, casting="blah", optimize=do_opt)
63
64            # dtype parameter must be a valid dtype
65            assert_raises(
66                TypeError, np.einsum, "", 0, dtype="bad_data_type", optimize=do_opt
67            )
68
69            # other keyword arguments are rejected
70            assert_raises(TypeError, np.einsum, "", 0, bad_arg=0, optimize=do_opt)
71
72            # issue 4528 revealed a segfault with this call
73            assert_raises(
74                (RuntimeError, TypeError), np.einsum, *(None,) * 63, optimize=do_opt
75            )
76
77            # number of operands must match count in subscripts string
78            assert_raises(
79                (RuntimeError, ValueError), np.einsum, "", 0, 0, optimize=do_opt
80            )
81            assert_raises(
82                (RuntimeError, ValueError), np.einsum, ",", 0, [0], [0], optimize=do_opt
83            )
84            assert_raises(
85                (RuntimeError, ValueError), np.einsum, ",", [0], optimize=do_opt
86            )
87
88            # can't have more subscripts than dimensions in the operand
89            assert_raises(
90                (RuntimeError, ValueError), np.einsum, "i", 0, optimize=do_opt
91            )
92            assert_raises(
93                (RuntimeError, ValueError), np.einsum, "ij", [0, 0], optimize=do_opt
94            )
95            assert_raises(
96                (RuntimeError, ValueError), np.einsum, "...i", 0, optimize=do_opt
97            )
98            assert_raises(
99                (RuntimeError, ValueError), np.einsum, "i...j", [0, 0], optimize=do_opt
100            )
101            assert_raises(
102                (RuntimeError, ValueError), np.einsum, "i...", 0, optimize=do_opt
103            )
104            assert_raises(
105                (RuntimeError, ValueError), np.einsum, "ij...", [0, 0], optimize=do_opt
106            )
107
108            # invalid ellipsis
109            assert_raises(
110                (RuntimeError, ValueError), np.einsum, "i..", [0, 0], optimize=do_opt
111            )
112            assert_raises(
113                (RuntimeError, ValueError), np.einsum, ".i...", [0, 0], optimize=do_opt
114            )
115            assert_raises(
116                (RuntimeError, ValueError), np.einsum, "j->..j", [0, 0], optimize=do_opt
117            )
118            assert_raises(
119                (RuntimeError, ValueError),
120                np.einsum,
121                "j->.j...",
122                [0, 0],
123                optimize=do_opt,
124            )
125
126            # invalid subscript character
127            assert_raises(
128                (RuntimeError, ValueError), np.einsum, "i%...", [0, 0], optimize=do_opt
129            )
130            assert_raises(
131                (RuntimeError, ValueError), np.einsum, "...j$", [0, 0], optimize=do_opt
132            )
133            assert_raises(
134                (RuntimeError, ValueError), np.einsum, "i->&", [0, 0], optimize=do_opt
135            )
136
137            # output subscripts must appear in input
138            assert_raises(
139                (RuntimeError, ValueError), np.einsum, "i->ij", [0, 0], optimize=do_opt
140            )
141
142            # output subscripts may only be specified once
143            assert_raises(
144                (RuntimeError, ValueError),
145                np.einsum,
146                "ij->jij",
147                [[0, 0], [0, 0]],
148                optimize=do_opt,
149            )
150
151            # dimensions much match when being collapsed
152            assert_raises(
153                (RuntimeError, ValueError),
154                np.einsum,
155                "ii",
156                np.arange(6).reshape(2, 3),
157                optimize=do_opt,
158            )
159            assert_raises(
160                (RuntimeError, ValueError),
161                np.einsum,
162                "ii->i",
163                np.arange(6).reshape(2, 3),
164                optimize=do_opt,
165            )
166
167            # broadcasting to new dimensions must be enabled explicitly
168            assert_raises(
169                (RuntimeError, ValueError),
170                np.einsum,
171                "i",
172                np.arange(6).reshape(2, 3),
173                optimize=do_opt,
174            )
175            assert_raises(
176                (RuntimeError, ValueError),
177                np.einsum,
178                "i->i",
179                [[0, 1], [0, 1]],
180                out=np.arange(4).reshape(2, 2),
181                optimize=do_opt,
182            )
183            with assert_raises((RuntimeError, ValueError)):  # , match="'b'"):
184                # gh-11221 - 'c' erroneously appeared in the error message
185                a = np.ones((3, 3, 4, 5, 6))
186                b = np.ones((3, 4, 5))
187                np.einsum("aabcb,abc", a, b)
188
189            # Check order kwarg, asanyarray allows 1d to pass through
190            assert_raises(
191                (NotImplementedError, ValueError),
192                np.einsum,
193                "i->i",
194                np.arange(6).reshape(-1, 1),
195                optimize=do_opt,
196                order="d",
197            )
198
199    @xfail  # (reason="a view into smth else")
200    def test_einsum_views(self):
201        # pass-through
202        for do_opt in [True, False]:
203            a = np.arange(6)
204            a = a.reshape(2, 3)
205
206            b = np.einsum("...", a, optimize=do_opt)
207            assert_(b.tensor._base is a.tensor)
208
209            b = np.einsum(a, [Ellipsis], optimize=do_opt)
210            assert_(b.base is a)
211
212            b = np.einsum("ij", a, optimize=do_opt)
213            assert_(b.base is a)
214            assert_equal(b, a)
215
216            b = np.einsum(a, [0, 1], optimize=do_opt)
217            assert_(b.base is a)
218            assert_equal(b, a)
219
220            # output is writeable whenever input is writeable
221            b = np.einsum("...", a, optimize=do_opt)
222            assert_(b.flags["WRITEABLE"])
223            a.flags["WRITEABLE"] = False
224            b = np.einsum("...", a, optimize=do_opt)
225            assert_(not b.flags["WRITEABLE"])
226
227            # transpose
228            a = np.arange(6)
229            a.shape = (2, 3)
230
231            b = np.einsum("ji", a, optimize=do_opt)
232            assert_(b.base is a)
233            assert_equal(b, a.T)
234
235            b = np.einsum(a, [1, 0], optimize=do_opt)
236            assert_(b.base is a)
237            assert_equal(b, a.T)
238
239            # diagonal
240            a = np.arange(9)
241            a.shape = (3, 3)
242
243            b = np.einsum("ii->i", a, optimize=do_opt)
244            assert_(b.base is a)
245            assert_equal(b, [a[i, i] for i in range(3)])
246
247            b = np.einsum(a, [0, 0], [0], optimize=do_opt)
248            assert_(b.base is a)
249            assert_equal(b, [a[i, i] for i in range(3)])
250
251            # diagonal with various ways of broadcasting an additional dimension
252            a = np.arange(27)
253            a.shape = (3, 3, 3)
254
255            b = np.einsum("...ii->...i", a, optimize=do_opt)
256            assert_(b.base is a)
257            assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
258
259            b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0], optimize=do_opt)
260            assert_(b.base is a)
261            assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
262
263            b = np.einsum("ii...->...i", a, optimize=do_opt)
264            assert_(b.base is a)
265            assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(2, 0, 1)])
266
267            b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0], optimize=do_opt)
268            assert_(b.base is a)
269            assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(2, 0, 1)])
270
271            b = np.einsum("...ii->i...", a, optimize=do_opt)
272            assert_(b.base is a)
273            assert_equal(b, [a[:, i, i] for i in range(3)])
274
275            b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis], optimize=do_opt)
276            assert_(b.base is a)
277            assert_equal(b, [a[:, i, i] for i in range(3)])
278
279            b = np.einsum("jii->ij", a, optimize=do_opt)
280            assert_(b.base is a)
281            assert_equal(b, [a[:, i, i] for i in range(3)])
282
283            b = np.einsum(a, [1, 0, 0], [0, 1], optimize=do_opt)
284            assert_(b.base is a)
285            assert_equal(b, [a[:, i, i] for i in range(3)])
286
287            b = np.einsum("ii...->i...", a, optimize=do_opt)
288            assert_(b.base is a)
289            assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
290
291            b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis], optimize=do_opt)
292            assert_(b.base is a)
293            assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
294
295            b = np.einsum("i...i->i...", a, optimize=do_opt)
296            assert_(b.base is a)
297            assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
298
299            b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis], optimize=do_opt)
300            assert_(b.base is a)
301            assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
302
303            b = np.einsum("i...i->...i", a, optimize=do_opt)
304            assert_(b.base is a)
305            assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(1, 0, 2)])
306
307            b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0], optimize=do_opt)
308            assert_(b.base is a)
309            assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(1, 0, 2)])
310
311            # triple diagonal
312            a = np.arange(27)
313            a.shape = (3, 3, 3)
314
315            b = np.einsum("iii->i", a, optimize=do_opt)
316            assert_(b.base is a)
317            assert_equal(b, [a[i, i, i] for i in range(3)])
318
319            b = np.einsum(a, [0, 0, 0], [0], optimize=do_opt)
320            assert_(b.base is a)
321            assert_equal(b, [a[i, i, i] for i in range(3)])
322
323            # swap axes
324            a = np.arange(24)
325            a.shape = (2, 3, 4)
326
327            b = np.einsum("ijk->jik", a, optimize=do_opt)
328            assert_(b.base is a)
329            assert_equal(b, a.swapaxes(0, 1))
330
331            b = np.einsum(a, [0, 1, 2], [1, 0, 2], optimize=do_opt)
332            assert_(b.base is a)
333            assert_equal(b, a.swapaxes(0, 1))
334
335    #  @np._no_nep50_warning()
336    def check_einsum_sums(self, dtype, do_opt=False):
337        dtype = np.dtype(dtype)
338        # Check various sums.  Does many sizes to exercise unrolled loops.
339
340        # sum(a, axis=-1)
341        for n in range(1, 17):
342            a = np.arange(n, dtype=dtype)
343            assert_equal(
344                np.einsum("i->", a, optimize=do_opt), np.sum(a, axis=-1).astype(dtype)
345            )
346            assert_equal(
347                np.einsum(a, [0], [], optimize=do_opt), np.sum(a, axis=-1).astype(dtype)
348            )
349
350        for n in range(1, 17):
351            a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
352            assert_equal(
353                np.einsum("...i->...", a, optimize=do_opt),
354                np.sum(a, axis=-1).astype(dtype),
355            )
356            assert_equal(
357                np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
358                np.sum(a, axis=-1).astype(dtype),
359            )
360
361        # sum(a, axis=0)
362        for n in range(1, 17):
363            a = np.arange(2 * n, dtype=dtype).reshape(2, n)
364            assert_equal(
365                np.einsum("i...->...", a, optimize=do_opt),
366                np.sum(a, axis=0).astype(dtype),
367            )
368            assert_equal(
369                np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
370                np.sum(a, axis=0).astype(dtype),
371            )
372
373        for n in range(1, 17):
374            a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
375            assert_equal(
376                np.einsum("i...->...", a, optimize=do_opt),
377                np.sum(a, axis=0).astype(dtype),
378            )
379            assert_equal(
380                np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
381                np.sum(a, axis=0).astype(dtype),
382            )
383
384        # trace(a)
385        for n in range(1, 17):
386            a = np.arange(n * n, dtype=dtype).reshape(n, n)
387            assert_equal(np.einsum("ii", a, optimize=do_opt), np.trace(a).astype(dtype))
388            assert_equal(
389                np.einsum(a, [0, 0], optimize=do_opt),  # torch?
390                np.trace(a).astype(dtype),
391            )
392
393            # gh-15961: should accept numpy int64 type in subscript list
394        #     np_array = np.asarray([0, 0])
395        #     assert_equal(np.einsum(a, np_array, optimize=do_opt),
396        #                  np.trace(a).astype(dtype))
397        #     assert_equal(np.einsum(a, list(np_array), optimize=do_opt),
398        #                  np.trace(a).astype(dtype))
399
400        # multiply(a, b)
401        assert_equal(np.einsum("..., ...", 3, 4), 12)  # scalar case
402        for n in range(1, 17):
403            a = np.arange(3 * n, dtype=dtype).reshape(3, n)
404            b = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
405            assert_equal(
406                np.einsum("..., ...", a, b, optimize=do_opt), np.multiply(a, b)
407            )
408            assert_equal(
409                np.einsum(a, [Ellipsis], b, [Ellipsis], optimize=do_opt),
410                np.multiply(a, b),
411            )
412
413        # inner(a,b)
414        for n in range(1, 17):
415            a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
416            b = np.arange(n, dtype=dtype)
417            assert_equal(np.einsum("...i, ...i", a, b, optimize=do_opt), np.inner(a, b))
418            assert_equal(
419                np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0], optimize=do_opt),
420                np.inner(a, b),
421            )
422
423        for n in range(1, 11):
424            a = np.arange(n * 3 * 2, dtype=dtype).reshape(n, 3, 2)
425            b = np.arange(n, dtype=dtype)
426            assert_equal(
427                np.einsum("i..., i...", a, b, optimize=do_opt), np.inner(a.T, b.T).T
428            )
429            assert_equal(
430                np.einsum(a, [0, Ellipsis], b, [0, Ellipsis], optimize=do_opt),
431                np.inner(a.T, b.T).T,
432            )
433
434        # outer(a,b)
435        for n in range(1, 17):
436            a = np.arange(3, dtype=dtype) + 1
437            b = np.arange(n, dtype=dtype) + 1
438            assert_equal(np.einsum("i,j", a, b, optimize=do_opt), np.outer(a, b))
439            assert_equal(np.einsum(a, [0], b, [1], optimize=do_opt), np.outer(a, b))
440
441        # Suppress the complex warnings for the 'as f8' tests
442        with suppress_warnings() as sup:
443            #         sup.filter(np.ComplexWarning)
444
445            # matvec(a,b) / a.dot(b) where a is matrix, b is vector
446            for n in range(1, 17):
447                a = np.arange(4 * n, dtype=dtype).reshape(4, n)
448                b = np.arange(n, dtype=dtype)
449                assert_equal(np.einsum("ij, j", a, b, optimize=do_opt), np.dot(a, b))
450                assert_equal(
451                    np.einsum(a, [0, 1], b, [1], optimize=do_opt), np.dot(a, b)
452                )
453
454                c = np.arange(4, dtype=dtype)
455                np.einsum(
456                    "ij,j", a, b, out=c, dtype="f8", casting="unsafe", optimize=do_opt
457                )
458                assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype))
459                c[...] = 0
460                np.einsum(
461                    a,
462                    [0, 1],
463                    b,
464                    [1],
465                    out=c,
466                    dtype="f8",
467                    casting="unsafe",
468                    optimize=do_opt,
469                )
470                assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype))
471
472            for n in range(1, 17):
473                a = np.arange(4 * n, dtype=dtype).reshape(4, n)
474                b = np.arange(n, dtype=dtype)
475                assert_equal(
476                    np.einsum("ji,j", a.T, b.T, optimize=do_opt), np.dot(b.T, a.T)
477                )
478                assert_equal(
479                    np.einsum(a.T, [1, 0], b.T, [1], optimize=do_opt), np.dot(b.T, a.T)
480                )
481
482                c = np.arange(4, dtype=dtype)
483                np.einsum(
484                    "ji,j",
485                    a.T,
486                    b.T,
487                    out=c,
488                    dtype="f8",
489                    casting="unsafe",
490                    optimize=do_opt,
491                )
492                assert_equal(
493                    c, np.dot(b.T.astype("f8"), a.T.astype("f8")).astype(dtype)
494                )
495                c[...] = 0
496                np.einsum(
497                    a.T,
498                    [1, 0],
499                    b.T,
500                    [1],
501                    out=c,
502                    dtype="f8",
503                    casting="unsafe",
504                    optimize=do_opt,
505                )
506                assert_equal(
507                    c, np.dot(b.T.astype("f8"), a.T.astype("f8")).astype(dtype)
508                )
509
510            # matmat(a,b) / a.dot(b) where a is matrix, b is matrix
511            for n in range(1, 17):
512                if n < 8 or dtype != "f2":
513                    a = np.arange(4 * n, dtype=dtype).reshape(4, n)
514                    b = np.arange(n * 6, dtype=dtype).reshape(n, 6)
515                    assert_equal(
516                        np.einsum("ij,jk", a, b, optimize=do_opt), np.dot(a, b)
517                    )
518                    assert_equal(
519                        np.einsum(a, [0, 1], b, [1, 2], optimize=do_opt), np.dot(a, b)
520                    )
521
522            for n in range(1, 17):
523                a = np.arange(4 * n, dtype=dtype).reshape(4, n)
524                b = np.arange(n * 6, dtype=dtype).reshape(n, 6)
525                c = np.arange(24, dtype=dtype).reshape(4, 6)
526                np.einsum(
527                    "ij,jk", a, b, out=c, dtype="f8", casting="unsafe", optimize=do_opt
528                )
529                assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype))
530                c[...] = 0
531                np.einsum(
532                    a,
533                    [0, 1],
534                    b,
535                    [1, 2],
536                    out=c,
537                    dtype="f8",
538                    casting="unsafe",
539                    optimize=do_opt,
540                )
541                assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype))
542
543            # matrix triple product (note this is not currently an efficient
544            # way to multiply 3 matrices)
545            a = np.arange(12, dtype=dtype).reshape(3, 4)
546            b = np.arange(20, dtype=dtype).reshape(4, 5)
547            c = np.arange(30, dtype=dtype).reshape(5, 6)
548            if dtype != "f2":
549                assert_equal(
550                    np.einsum("ij,jk,kl", a, b, c, optimize=do_opt), a.dot(b).dot(c)
551                )
552                assert_equal(
553                    np.einsum(a, [0, 1], b, [1, 2], c, [2, 3], optimize=do_opt),
554                    a.dot(b).dot(c),
555                )
556
557            d = np.arange(18, dtype=dtype).reshape(3, 6)
558            np.einsum(
559                "ij,jk,kl",
560                a,
561                b,
562                c,
563                out=d,
564                dtype="f8",
565                casting="unsafe",
566                optimize=do_opt,
567            )
568            tgt = a.astype("f8").dot(b.astype("f8"))
569            tgt = tgt.dot(c.astype("f8")).astype(dtype)
570            assert_equal(d, tgt)
571
572            d[...] = 0
573            np.einsum(
574                a,
575                [0, 1],
576                b,
577                [1, 2],
578                c,
579                [2, 3],
580                out=d,
581                dtype="f8",
582                casting="unsafe",
583                optimize=do_opt,
584            )
585            tgt = a.astype("f8").dot(b.astype("f8"))
586            tgt = tgt.dot(c.astype("f8")).astype(dtype)
587            assert_equal(d, tgt)
588
589            # tensordot(a, b)
590            if np.dtype(dtype) != np.dtype("f2"):
591                a = np.arange(60, dtype=dtype).reshape(3, 4, 5)
592                b = np.arange(24, dtype=dtype).reshape(4, 3, 2)
593                assert_equal(
594                    np.einsum("ijk, jil -> kl", a, b),
595                    np.tensordot(a, b, axes=([1, 0], [0, 1])),
596                )
597                assert_equal(
598                    np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3]),
599                    np.tensordot(a, b, axes=([1, 0], [0, 1])),
600                )
601
602                c = np.arange(10, dtype=dtype).reshape(5, 2)
603                np.einsum(
604                    "ijk,jil->kl",
605                    a,
606                    b,
607                    out=c,
608                    dtype="f8",
609                    casting="unsafe",
610                    optimize=do_opt,
611                )
612                assert_equal(
613                    c,
614                    np.tensordot(
615                        a.astype("f8"), b.astype("f8"), axes=([1, 0], [0, 1])
616                    ).astype(dtype),
617                )
618                c[...] = 0
619                np.einsum(
620                    a,
621                    [0, 1, 2],
622                    b,
623                    [1, 0, 3],
624                    [2, 3],
625                    out=c,
626                    dtype="f8",
627                    casting="unsafe",
628                    optimize=do_opt,
629                )
630                assert_equal(
631                    c,
632                    np.tensordot(
633                        a.astype("f8"), b.astype("f8"), axes=([1, 0], [0, 1])
634                    ).astype(dtype),
635                )
636
637        # logical_and(logical_and(a!=0, b!=0), c!=0)
638        neg_val = -2 if dtype.kind != "u" else np.iinfo(dtype).max - 1
639        a = np.array([1, 3, neg_val, 0, 12, 13, 0, 1], dtype=dtype)
640        b = np.array([0, 3.5, 0.0, neg_val, 0, 1, 3, 12], dtype=dtype)
641        c = np.array([True, True, False, True, True, False, True, True])
642
643        assert_equal(
644            np.einsum(
645                "i,i,i->i", a, b, c, dtype="?", casting="unsafe", optimize=do_opt
646            ),
647            np.logical_and(np.logical_and(a != 0, b != 0), c != 0),
648        )
649        assert_equal(
650            np.einsum(a, [0], b, [0], c, [0], [0], dtype="?", casting="unsafe"),
651            np.logical_and(np.logical_and(a != 0, b != 0), c != 0),
652        )
653
654        a = np.arange(9, dtype=dtype)
655        assert_equal(np.einsum(",i->", 3, a), 3 * np.sum(a))
656        assert_equal(np.einsum(3, [], a, [0], []), 3 * np.sum(a))
657        assert_equal(np.einsum("i,->", a, 3), 3 * np.sum(a))
658        assert_equal(np.einsum(a, [0], 3, [], []), 3 * np.sum(a))
659
660        # Various stride0, contiguous, and SSE aligned variants
661        for n in range(1, 25):
662            a = np.arange(n, dtype=dtype)
663            if np.dtype(dtype).itemsize > 1:
664                assert_equal(
665                    np.einsum("...,...", a, a, optimize=do_opt), np.multiply(a, a)
666                )
667                assert_equal(np.einsum("i,i", a, a, optimize=do_opt), np.dot(a, a))
668                assert_equal(np.einsum("i,->i", a, 2, optimize=do_opt), 2 * a)
669                assert_equal(np.einsum(",i->i", 2, a, optimize=do_opt), 2 * a)
670                assert_equal(np.einsum("i,->", a, 2, optimize=do_opt), 2 * np.sum(a))
671                assert_equal(np.einsum(",i->", 2, a, optimize=do_opt), 2 * np.sum(a))
672
673                assert_equal(
674                    np.einsum("...,...", a[1:], a[:-1], optimize=do_opt),
675                    np.multiply(a[1:], a[:-1]),
676                )
677                assert_equal(
678                    np.einsum("i,i", a[1:], a[:-1], optimize=do_opt),
679                    np.dot(a[1:], a[:-1]),
680                )
681                assert_equal(np.einsum("i,->i", a[1:], 2, optimize=do_opt), 2 * a[1:])
682                assert_equal(np.einsum(",i->i", 2, a[1:], optimize=do_opt), 2 * a[1:])
683                assert_equal(
684                    np.einsum("i,->", a[1:], 2, optimize=do_opt), 2 * np.sum(a[1:])
685                )
686                assert_equal(
687                    np.einsum(",i->", 2, a[1:], optimize=do_opt), 2 * np.sum(a[1:])
688                )
689
690        # An object array, summed as the data type
691        #    a = np.arange(9, dtype=object)
692        #
693        #    b = np.einsum("i->", a, dtype=dtype, casting='unsafe')
694        #    assert_equal(b, np.sum(a))
695        #    assert_equal(b.dtype, np.dtype(dtype))
696        #
697        #    b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe')
698        #    assert_equal(b, np.sum(a))
699        #    assert_equal(b.dtype, np.dtype(dtype))
700
701        # A case which was failing (ticket #1885)
702        p = np.arange(2) + 1
703        q = np.arange(4).reshape(2, 2) + 3
704        r = np.arange(4).reshape(2, 2) + 7
705        assert_equal(np.einsum("z,mz,zm->", p, q, r), 253)
706
707        # singleton dimensions broadcast (gh-10343)
708        p = np.ones((10, 2))
709        q = np.ones((1, 2))
710        assert_array_equal(
711            np.einsum("ij,ij->j", p, q, optimize=True),
712            np.einsum("ij,ij->j", p, q, optimize=False),
713        )
714        assert_array_equal(np.einsum("ij,ij->j", p, q, optimize=True), [10.0] * 2)
715
716        # a blas-compatible contraction broadcasting case which was failing
717        # for optimize=True (ticket #10930)
718        x = np.array([2.0, 3.0])
719        y = np.array([4.0])
720        assert_array_equal(np.einsum("i, i", x, y, optimize=False), 20.0)
721        assert_array_equal(np.einsum("i, i", x, y, optimize=True), 20.0)
722
723        # all-ones array was bypassing bug (ticket #10930)
724        p = np.ones((1, 5)) / 2
725        q = np.ones((5, 5)) / 2
726        for optimize in (True, False):
727            assert_array_equal(
728                np.einsum("...ij,...jk->...ik", p, p, optimize=optimize),
729                np.einsum("...ij,...jk->...ik", p, q, optimize=optimize),
730            )
731            assert_array_equal(
732                np.einsum("...ij,...jk->...ik", p, q, optimize=optimize),
733                np.full((1, 5), 1.25),
734            )
735
736        # Cases which were failing (gh-10899)
737        x = np.eye(2, dtype=dtype)
738        y = np.ones(2, dtype=dtype)
739        assert_array_equal(
740            np.einsum("ji,i->", x, y, optimize=optimize), [2.0]
741        )  # contig_contig_outstride0_two
742        assert_array_equal(
743            np.einsum("i,ij->", y, x, optimize=optimize), [2.0]
744        )  # stride0_contig_outstride0_two
745        assert_array_equal(
746            np.einsum("ij,i->", x, y, optimize=optimize), [2.0]
747        )  # contig_stride0_outstride0_two
748
749    @xfail  # (reason="int overflow differs in numpy and pytorch")
750    def test_einsum_sums_int8(self):
751        self.check_einsum_sums("i1")
752
753    @xfail  # (reason="int overflow differs in numpy and pytorch")
754    def test_einsum_sums_uint8(self):
755        self.check_einsum_sums("u1")
756
757    @xfail  # (reason="int overflow differs in numpy and pytorch")
758    def test_einsum_sums_int16(self):
759        self.check_einsum_sums("i2")
760
761    def test_einsum_sums_int32(self):
762        self.check_einsum_sums("i4")
763        self.check_einsum_sums("i4", True)
764
765    def test_einsum_sums_int64(self):
766        self.check_einsum_sums("i8")
767
768    @xfail  # (reason="np.float16(4641) == 4640.0")
769    def test_einsum_sums_float16(self):
770        self.check_einsum_sums("f2")
771
772    def test_einsum_sums_float32(self):
773        self.check_einsum_sums("f4")
774
775    def test_einsum_sums_float64(self):
776        self.check_einsum_sums("f8")
777        self.check_einsum_sums("f8", True)
778
779    def test_einsum_sums_cfloat64(self):
780        self.check_einsum_sums("c8")
781        self.check_einsum_sums("c8", True)
782
783    def test_einsum_sums_cfloat128(self):
784        self.check_einsum_sums("c16")
785
786    def test_einsum_misc(self):
787        # This call used to crash because of a bug in
788        # PyArray_AssignZero
789        a = np.ones((1, 2))
790        b = np.ones((2, 2, 1))
791        assert_equal(np.einsum("ij...,j...->i...", a, b), [[[2], [2]]])
792        assert_equal(np.einsum("ij...,j...->i...", a, b, optimize=True), [[[2], [2]]])
793
794        # Regression test for issue #10369 (test unicode inputs with Python 2)
795        assert_equal(np.einsum("ij...,j...->i...", a, b), [[[2], [2]]])
796        assert_equal(np.einsum("...i,...i", [1, 2, 3], [2, 3, 4]), 20)
797        assert_equal(
798            np.einsum("...i,...i", [1, 2, 3], [2, 3, 4], optimize="greedy"), 20
799        )
800
801        # The iterator had an issue with buffering this reduction
802        a = np.ones((5, 12, 4, 2, 3), np.int64)
803        b = np.ones((5, 12, 11), np.int64)
804        assert_equal(
805            np.einsum("ijklm,ijn,ijn->", a, b, b), np.einsum("ijklm,ijn->", a, b)
806        )
807        assert_equal(
808            np.einsum("ijklm,ijn,ijn->", a, b, b, optimize=True),
809            np.einsum("ijklm,ijn->", a, b, optimize=True),
810        )
811
812        # Issue #2027, was a problem in the contiguous 3-argument
813        # inner loop implementation
814        a = np.arange(1, 3)
815        b = np.arange(1, 5).reshape(2, 2)
816        c = np.arange(1, 9).reshape(4, 2)
817        assert_equal(
818            np.einsum("x,yx,zx->xzy", a, b, c),
819            [
820                [[1, 3], [3, 9], [5, 15], [7, 21]],
821                [[8, 16], [16, 32], [24, 48], [32, 64]],
822            ],
823        )
824        assert_equal(
825            np.einsum("x,yx,zx->xzy", a, b, c, optimize=True),
826            [
827                [[1, 3], [3, 9], [5, 15], [7, 21]],
828                [[8, 16], [16, 32], [24, 48], [32, 64]],
829            ],
830        )
831
832        # Ensure explicitly setting out=None does not cause an error
833        # see issue gh-15776 and issue gh-15256
834        assert_equal(np.einsum("i,j", [1], [2], out=None), [[2]])
835
836    def test_subscript_range(self):
837        # Issue #7741, make sure that all letters of Latin alphabet (both uppercase & lowercase) can be used
838        # when creating a subscript from arrays
839        a = np.ones((2, 3))
840        b = np.ones((3, 4))
841        np.einsum(a, [0, 20], b, [20, 2], [0, 2], optimize=False)
842        np.einsum(a, [0, 27], b, [27, 2], [0, 2], optimize=False)
843        np.einsum(a, [0, 51], b, [51, 2], [0, 2], optimize=False)
844        assert_raises(
845            ValueError,
846            lambda: np.einsum(a, [0, 52], b, [52, 2], [0, 2], optimize=False),
847        )
848        assert_raises(
849            ValueError,
850            lambda: np.einsum(a, [-1, 5], b, [5, 2], [-1, 2], optimize=False),
851        )
852
853    def test_einsum_broadcast(self):
854        # Issue #2455 change in handling ellipsis
855        # remove the 'middle broadcast' error
856        # only use the 'RIGHT' iteration in prepare_op_axes
857        # adds auto broadcast on left where it belongs
858        # broadcast on right has to be explicit
859        # We need to test the optimized parsing as well
860
861        A = np.arange(2 * 3 * 4).reshape(2, 3, 4)
862        B = np.arange(3)
863        ref = np.einsum("ijk,j->ijk", A, B, optimize=False)
864        for opt in [True, False]:
865            assert_equal(np.einsum("ij...,j...->ij...", A, B, optimize=opt), ref)
866            assert_equal(np.einsum("ij...,...j->ij...", A, B, optimize=opt), ref)
867            assert_equal(
868                np.einsum("ij...,j->ij...", A, B, optimize=opt), ref
869            )  # used to raise error
870
871        A = np.arange(12).reshape((4, 3))
872        B = np.arange(6).reshape((3, 2))
873        ref = np.einsum("ik,kj->ij", A, B, optimize=False)
874        for opt in [True, False]:
875            assert_equal(np.einsum("ik...,k...->i...", A, B, optimize=opt), ref)
876            assert_equal(np.einsum("ik...,...kj->i...j", A, B, optimize=opt), ref)
877            assert_equal(
878                np.einsum("...k,kj", A, B, optimize=opt), ref
879            )  # used to raise error
880            assert_equal(
881                np.einsum("ik,k...->i...", A, B, optimize=opt), ref
882            )  # used to raise error
883
884        dims = [2, 3, 4, 5]
885        a = np.arange(np.prod(dims)).reshape(dims)
886        v = np.arange(dims[2])
887        ref = np.einsum("ijkl,k->ijl", a, v, optimize=False)
888        for opt in [True, False]:
889            assert_equal(np.einsum("ijkl,k", a, v, optimize=opt), ref)
890            assert_equal(
891                np.einsum("...kl,k", a, v, optimize=opt), ref
892            )  # used to raise error
893            assert_equal(np.einsum("...kl,k...", a, v, optimize=opt), ref)
894
895        J, K, M = 160, 160, 120
896        A = np.arange(J * K * M).reshape(1, 1, 1, J, K, M)
897        B = np.arange(J * K * M * 3).reshape(J, K, M, 3)
898        ref = np.einsum("...lmn,...lmno->...o", A, B, optimize=False)
899        for opt in [True, False]:
900            assert_equal(
901                np.einsum("...lmn,lmno->...o", A, B, optimize=opt), ref
902            )  # used to raise error
903
904    def test_einsum_fixedstridebug(self):
905        # Issue #4485 obscure einsum bug
906        # This case revealed a bug in nditer where it reported a stride
907        # as 'fixed' (0) when it was in fact not fixed during processing
908        # (0 or 4). The reason for the bug was that the check for a fixed
909        # stride was using the information from the 2D inner loop reuse
910        # to restrict the iteration dimensions it had to validate to be
911        # the same, but that 2D inner loop reuse logic is only triggered
912        # during the buffer copying step, and hence it was invalid to
913        # rely on those values. The fix is to check all the dimensions
914        # of the stride in question, which in the test case reveals that
915        # the stride is not fixed.
916        #
917        # NOTE: This test is triggered by the fact that the default buffersize,
918        #       used by einsum, is 8192, and 3*2731 = 8193, is larger than that
919        #       and results in a mismatch between the buffering and the
920        #       striding for operand A.
921        A = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
922        B = np.arange(2 * 3 * 2731).reshape(2, 3, 2731).astype(np.int16)
923        es = np.einsum("cl, cpx->lpx", A, B)
924        tp = np.tensordot(A, B, axes=(0, 0))
925        assert_equal(es, tp)
926        # The following is the original test case from the bug report,
927        # made repeatable by changing random arrays to aranges.
928        A = np.arange(3 * 3).reshape(3, 3).astype(np.float64)
929        B = np.arange(3 * 3 * 64 * 64).reshape(3, 3, 64, 64).astype(np.float32)
930        es = np.einsum("cl, cpxy->lpxy", A, B)
931        tp = np.tensordot(A, B, axes=(0, 0))
932        assert_equal(es, tp)
933
934    def test_einsum_fixed_collapsingbug(self):
935        # Issue #5147.
936        # The bug only occurred when output argument of einssum was used.
937        x = np.random.normal(0, 1, (5, 5, 5, 5))
938        y1 = np.zeros((5, 5))
939        np.einsum("aabb->ab", x, out=y1)
940        idx = np.arange(5)
941        y2 = x[idx[:, None], idx[:, None], idx, idx]
942        assert_equal(y1, y2)
943
944    def test_einsum_failed_on_p9_and_s390x(self):
945        # Issues gh-14692 and gh-12689
946        # Bug with signed vs unsigned char errored on power9 and s390x Linux
947        tensor = np.random.random_sample((10, 10, 10, 10))
948        x = np.einsum("ijij->", tensor)
949        y = tensor.trace(axis1=0, axis2=2).trace()
950        assert_allclose(x, y)
951
952    @xfail  # (reason="no base")
953    def test_einsum_all_contig_non_contig_output(self):
954        # Issue gh-5907, tests that the all contiguous special case
955        # actually checks the contiguity of the output
956        x = np.ones((5, 5))
957        out = np.ones(10)[::2]
958        correct_base = np.ones(10)
959        correct_base[::2] = 5
960        # Always worked (inner iteration is done with 0-stride):
961        np.einsum("mi,mi,mi->m", x, x, x, out=out)
962        assert_array_equal(out.base, correct_base)
963        # Example 1:
964        out = np.ones(10)[::2]
965        np.einsum("im,im,im->m", x, x, x, out=out)
966        assert_array_equal(out.base, correct_base)
967        # Example 2, buffering causes x to be contiguous but
968        # special cases do not catch the operation before:
969        out = np.ones((2, 2, 2))[..., 0]
970        correct_base = np.ones((2, 2, 2))
971        correct_base[..., 0] = 2
972        x = np.ones((2, 2), np.float32)
973        np.einsum("ij,jk->ik", x, x, out=out)
974        assert_array_equal(out.base, correct_base)
975
    @parametrize("dtype", np.typecodes["AllFloat"] + np.typecodes["AllInteger"])
    def test_different_paths(self, dtype):
        """Cover einsum's specialized contiguous/strided/scalar code paths.

        Parametrized over every float and integer type code. Originally added
        to cover a broken float16 path (gh-20305). Most cases are likely
        covered elsewhere as well, at least partially.
        """
        dtype = np.dtype(dtype)
        # Simple test, designed to exercise most specialized code paths.
        # Note the +0.5 for floats: this makes sure we use a float value
        # where the results must be exact.
        arr = (np.arange(7) + 0.5).astype(dtype)
        scalar = np.array(2, dtype=dtype)

        # contig -> scalar:
        res = np.einsum("i->", arr)
        assert res == arr.sum()
        # contig, contig -> contig:
        res = np.einsum("i,i->i", arr, arr)
        assert_array_equal(res, arr * arr)
        # noncontig, noncontig -> contig:
        # (``arr.repeat(2)[::2]`` holds the same values as ``arr`` but with a step of 2)
        res = np.einsum("i,i->i", arr.repeat(2)[::2], arr.repeat(2)[::2])
        assert_array_equal(res, arr * arr)
        # contig + contig -> scalar
        assert np.einsum("i,i->", arr, arr) == (arr * arr).sum()
        # contig + scalar -> contig (with out)
        out = np.ones(7, dtype=dtype)
        res = np.einsum("i,->i", arr, dtype.type(2), out=out)
        assert_array_equal(res, arr * dtype.type(2))
        # scalar + contig -> contig (no ``out`` argument in this case)
        res = np.einsum(",i->i", scalar, arr)
        assert_array_equal(res, arr * dtype.type(2))
        # scalar + contig -> scalar
        res = np.einsum(",i->", scalar, arr)
        # Compare against einsum rather than .sum() so that both sides share
        # the same summation round-off:
        assert res == np.einsum("i->", scalar * arr)
        # contig + scalar -> scalar
        res = np.einsum("i,->", arr, scalar)
        # Compare against einsum rather than .sum() so that both sides share
        # the same summation round-off:
        assert res == np.einsum("i->", scalar * arr)

        if dtype in ["e", "B", "b"]:
            # FIXME make xfail
            # float16 ("e"), uint8 ("B") and int8 ("b") stop here; the
            # multi-operand reductions below are skipped for them.
            raise SkipTest("overflow differs in pytorch and numpy")

        # contig + contig + contig -> scalar
        arr = np.array([0.5, 0.5, 0.25, 4.5, 3.0], dtype=dtype)
        res = np.einsum("i,i,i->", arr, arr, arr)
        assert_array_equal(res, (arr * arr * arr).sum())
        # four arrays:
        res = np.einsum("i,i,i,i->", arr, arr, arr, arr)
        assert_array_equal(res, (arr * arr * arr * arr).sum())
1025
1026    def test_small_boolean_arrays(self):
1027        # See gh-5946.
1028        # Use array of True embedded in False.
1029        a = np.zeros((16, 1, 1), dtype=np.bool_)[:2]
1030        a[...] = True
1031        out = np.zeros((16, 1, 1), dtype=np.bool_)[:2]
1032        tgt = np.ones((2, 1, 1), dtype=np.bool_)
1033        res = np.einsum("...ij,...jk->...ik", a, a, out=out)
1034        assert_equal(res, tgt)
1035
1036    def test_out_is_res(self):
1037        a = np.arange(9).reshape(3, 3)
1038        res = np.einsum("...ij,...jk->...ik", a, a, out=a)
1039        assert res is a
1040
1041    def optimize_compare(self, subscripts, operands=None):
1042        # Tests all paths of the optimization function against
1043        # conventional einsum
1044        if operands is None:
1045            args = [subscripts]
1046            terms = subscripts.split("->")[0].split(",")
1047            for term in terms:
1048                dims = [global_size_dict[x] for x in term]
1049                args.append(np.random.rand(*dims))
1050        else:
1051            args = [subscripts] + operands
1052
1053        noopt = np.einsum(*args, optimize=False)
1054        opt = np.einsum(*args, optimize="greedy")
1055        assert_almost_equal(opt, noopt)
1056        opt = np.einsum(*args, optimize="optimal")
1057        assert_almost_equal(opt, noopt)
1058
1059    def test_hadamard_like_products(self):
1060        # Hadamard outer products
1061        self.optimize_compare("a,ab,abc->abc")
1062        self.optimize_compare("a,b,ab->ab")
1063
1064    def test_index_transformations(self):
1065        # Simple index transformation cases
1066        self.optimize_compare("ea,fb,gc,hd,abcd->efgh")
1067        self.optimize_compare("ea,fb,abcd,gc,hd->efgh")
1068        self.optimize_compare("abcd,ea,fb,gc,hd->efgh")
1069
1070    def test_complex(self):
1071        # Long test cases
1072        self.optimize_compare("acdf,jbje,gihb,hfac,gfac,gifabc,hfac")
1073        self.optimize_compare("acdf,jbje,gihb,hfac,gfac,gifabc,hfac")
1074        self.optimize_compare("cd,bdhe,aidb,hgca,gc,hgibcd,hgac")
1075        self.optimize_compare("abhe,hidj,jgba,hiab,gab")
1076        self.optimize_compare("bde,cdh,agdb,hica,ibd,hgicd,hiac")
1077        self.optimize_compare("chd,bde,agbc,hiad,hgc,hgi,hiad")
1078        self.optimize_compare("chd,bde,agbc,hiad,bdi,cgh,agdb")
1079        self.optimize_compare("bdhe,acad,hiab,agac,hibd")
1080
1081    def test_collapse(self):
1082        # Inner products
1083        self.optimize_compare("ab,ab,c->")
1084        self.optimize_compare("ab,ab,c->c")
1085        self.optimize_compare("ab,ab,cd,cd->")
1086        self.optimize_compare("ab,ab,cd,cd->ac")
1087        self.optimize_compare("ab,ab,cd,cd->cd")
1088        self.optimize_compare("ab,ab,cd,cd,ef,ef->")
1089
1090    def test_expand(self):
1091        # Outer products
1092        self.optimize_compare("ab,cd,ef->abcdef")
1093        self.optimize_compare("ab,cd,ef->acdf")
1094        self.optimize_compare("ab,cd,de->abcde")
1095        self.optimize_compare("ab,cd,de->be")
1096        self.optimize_compare("ab,bcd,cd->abcd")
1097        self.optimize_compare("ab,bcd,cd->abd")
1098
1099    def test_edge_cases(self):
1100        # Difficult edge cases for optimization
1101        self.optimize_compare("eb,cb,fb->cef")
1102        self.optimize_compare("dd,fb,be,cdb->cef")
1103        self.optimize_compare("bca,cdb,dbf,afc->")
1104        self.optimize_compare("dcc,fce,ea,dbf->ab")
1105        self.optimize_compare("fdf,cdd,ccd,afe->ae")
1106        self.optimize_compare("abcd,ad")
1107        self.optimize_compare("ed,fcd,ff,bcf->be")
1108        self.optimize_compare("baa,dcf,af,cde->be")
1109        self.optimize_compare("bd,db,eac->ace")
1110        self.optimize_compare("fff,fae,bef,def->abd")
1111        self.optimize_compare("efc,dbc,acf,fd->abe")
1112        self.optimize_compare("ba,ac,da->bcd")
1113
1114    def test_inner_product(self):
1115        # Inner products
1116        self.optimize_compare("ab,ab")
1117        self.optimize_compare("ab,ba")
1118        self.optimize_compare("abc,abc")
1119        self.optimize_compare("abc,bac")
1120        self.optimize_compare("abc,cba")
1121
1122    def test_random_cases(self):
1123        # Randomly built test cases
1124        self.optimize_compare("aab,fa,df,ecc->bde")
1125        self.optimize_compare("ecb,fef,bad,ed->ac")
1126        self.optimize_compare("bcf,bbb,fbf,fc->")
1127        self.optimize_compare("bb,ff,be->e")
1128        self.optimize_compare("bcb,bb,fc,fff->")
1129        self.optimize_compare("fbb,dfd,fc,fc->")
1130        self.optimize_compare("afd,ba,cc,dc->bf")
1131        self.optimize_compare("adb,bc,fa,cfc->d")
1132        self.optimize_compare("bbd,bda,fc,db->acf")
1133        self.optimize_compare("dba,ead,cad->bce")
1134        self.optimize_compare("aef,fbc,dca->bde")
1135
1136    def test_combined_views_mapping(self):
1137        # gh-10792
1138        a = np.arange(9).reshape(1, 1, 3, 1, 3)
1139        b = np.einsum("bbcdc->d", a)
1140        assert_equal(b, [12])
1141
1142    def test_broadcasting_dot_cases(self):
1143        # Ensures broadcasting cases are not mistaken for GEMM
1144
1145        a = np.random.rand(1, 5, 4)
1146        b = np.random.rand(4, 6)
1147        c = np.random.rand(5, 6)
1148        d = np.random.rand(10)
1149
1150        self.optimize_compare("ijk,kl,jl", operands=[a, b, c])
1151        self.optimize_compare("ijk,kl,jl,i->i", operands=[a, b, c, d])
1152
1153        e = np.random.rand(1, 1, 5, 4)
1154        f = np.random.rand(7, 7)
1155        self.optimize_compare("abjk,kl,jl", operands=[e, b, c])
1156        self.optimize_compare("abjk,kl,jl,ab->ab", operands=[e, b, c, f])
1157
1158        # Edge case found in gh-11308
1159        g = np.arange(64).reshape(2, 4, 8)
1160        self.optimize_compare("obk,ijk->ioj", operands=[g, g])
1161
1162    @xfail  # (reason="order='F' not supported")
1163    def test_output_order(self):
1164        # Ensure output order is respected for optimize cases, the below
1165        # conraction should yield a reshaped tensor view
1166        # gh-16415
1167
1168        a = np.ones((2, 3, 5), order="F")
1169        b = np.ones((4, 3), order="F")
1170
1171        for opt in [True, False]:
1172            tmp = np.einsum("...ft,mf->...mt", a, b, order="a", optimize=opt)
1173            assert_(tmp.flags.f_contiguous)
1174
1175            tmp = np.einsum("...ft,mf->...mt", a, b, order="f", optimize=opt)
1176            assert_(tmp.flags.f_contiguous)
1177
1178            tmp = np.einsum("...ft,mf->...mt", a, b, order="c", optimize=opt)
1179            assert_(tmp.flags.c_contiguous)
1180
1181            tmp = np.einsum("...ft,mf->...mt", a, b, order="k", optimize=opt)
1182            assert_(tmp.flags.c_contiguous is False)
1183            assert_(tmp.flags.f_contiguous is False)
1184
1185            tmp = np.einsum("...ft,mf->...mt", a, b, optimize=opt)
1186            assert_(tmp.flags.c_contiguous is False)
1187            assert_(tmp.flags.f_contiguous is False)
1188
1189        c = np.ones((4, 3), order="C")
1190        for opt in [True, False]:
1191            tmp = np.einsum("...ft,mf->...mt", a, c, order="a", optimize=opt)
1192            assert_(tmp.flags.c_contiguous)
1193
1194        d = np.ones((2, 3, 5), order="C")
1195        for opt in [True, False]:
1196            tmp = np.einsum("...ft,mf->...mt", d, c, order="a", optimize=opt)
1197            assert_(tmp.flags.c_contiguous)
1198
1199
@skip(reason="no pytorch analog")
class TestEinsumPath(TestCase):
    """Checks the contraction paths returned by ``np.einsum_path``.

    The whole class is skipped here (decorator reason: "no pytorch analog").
    """

    def build_operands(self, string, size_dict=global_size_dict):
        # Returns ``[string, *arrays]``: the subscript string followed by one
        # freshly generated random array per input term, with each axis sized
        # via ``size_dict``.
        operands = [string]
        terms = string.split("->")[0].split(",")
        for term in terms:
            dims = [size_dict[x] for x in term]
            operands.append(np.random.rand(*dims))

        return operands

    def assert_path_equal(self, comp, benchmark):
        # Checks that two einsum paths have the same length and identical
        # contraction tuples. Entry 0 (the "einsum_path" tag) is not compared.
        assert_(len(comp) == len(benchmark))
        ret = all(
            isinstance(step, tuple) and step == expected
            for step, expected in zip(comp[1:], benchmark[1:])
        )
        assert_(ret)

    def test_memory_contraints(self):
        # Ensure memory constraints are satisfied. With a memory limit of 0,
        # the expected path is a single contraction over all operands.

        outer_test = self.build_operands("a,b,c->abc")

        path, path_str = np.einsum_path(*outer_test, optimize=("greedy", 0))
        self.assert_path_equal(path, ["einsum_path", (0, 1, 2)])

        path, path_str = np.einsum_path(*outer_test, optimize=("optimal", 0))
        self.assert_path_equal(path, ["einsum_path", (0, 1, 2)])

        long_test = self.build_operands("acdf,jbje,gihb,hfac")
        path, path_str = np.einsum_path(*long_test, optimize=("greedy", 0))
        self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)])

        path, path_str = np.einsum_path(*long_test, optimize=("optimal", 0))
        self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)])

    def test_long_paths(self):
        # Long complex cases

        # Long test 1
        long_test1 = self.build_operands("acdf,jbje,gihb,hfac,gfac,gifabc,hfac")
        path, path_str = np.einsum_path(*long_test1, optimize="greedy")
        self.assert_path_equal(
            path, ["einsum_path", (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)]
        )

        path, path_str = np.einsum_path(*long_test1, optimize="optimal")
        self.assert_path_equal(
            path, ["einsum_path", (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)]
        )

        # Long test 2
        long_test2 = self.build_operands("chd,bde,agbc,hiad,bdi,cgh,agdb")
        path, path_str = np.einsum_path(*long_test2, optimize="greedy")
        self.assert_path_equal(
            path, ["einsum_path", (3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)]
        )

        path, path_str = np.einsum_path(*long_test2, optimize="optimal")
        self.assert_path_equal(
            path, ["einsum_path", (0, 5), (1, 4), (3, 4), (1, 3), (1, 2), (0, 1)]
        )

    def test_edge_paths(self):
        # Difficult edge cases

        # Edge test1
        edge_test1 = self.build_operands("eb,cb,fb->cef")
        path, path_str = np.einsum_path(*edge_test1, optimize="greedy")
        self.assert_path_equal(path, ["einsum_path", (0, 2), (0, 1)])

        path, path_str = np.einsum_path(*edge_test1, optimize="optimal")
        self.assert_path_equal(path, ["einsum_path", (0, 2), (0, 1)])

        # Edge test2
        edge_test2 = self.build_operands("dd,fb,be,cdb->cef")
        path, path_str = np.einsum_path(*edge_test2, optimize="greedy")
        self.assert_path_equal(path, ["einsum_path", (0, 3), (0, 1), (0, 1)])

        path, path_str = np.einsum_path(*edge_test2, optimize="optimal")
        self.assert_path_equal(path, ["einsum_path", (0, 3), (0, 1), (0, 1)])

        # Edge test3
        edge_test3 = self.build_operands("bca,cdb,dbf,afc->")
        path, path_str = np.einsum_path(*edge_test3, optimize="greedy")
        self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)])

        path, path_str = np.einsum_path(*edge_test3, optimize="optimal")
        self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)])

        # Edge test4
        edge_test4 = self.build_operands("dcc,fce,ea,dbf->ab")
        path, path_str = np.einsum_path(*edge_test4, optimize="greedy")
        self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 1), (0, 1)])

        path, path_str = np.einsum_path(*edge_test4, optimize="optimal")
        self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)])

        # Edge test5 (every index has size 20)
        edge_test5 = self.build_operands(
            "a,ac,ab,ad,cd,bd,bc->", size_dict={"a": 20, "b": 20, "c": 20, "d": 20}
        )
        path, path_str = np.einsum_path(*edge_test5, optimize="greedy")
        self.assert_path_equal(path, ["einsum_path", (0, 1), (0, 1, 2, 3, 4, 5)])

        path, path_str = np.einsum_path(*edge_test5, optimize="optimal")
        self.assert_path_equal(path, ["einsum_path", (0, 1), (0, 1, 2, 3, 4, 5)])

    def test_path_type_input(self):
        # Test explicit path handling
        path_test = self.build_operands("dcc,fce,ea,dbf->ab")

        path, path_str = np.einsum_path(*path_test, optimize=False)
        self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)])

        path, path_str = np.einsum_path(*path_test, optimize=True)
        self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 1), (0, 1)])

        exp_path = ["einsum_path", (0, 2), (0, 2), (0, 1)]
        path, path_str = np.einsum_path(*path_test, optimize=exp_path)
        self.assert_path_equal(path, exp_path)

        # Double check einsum works on the input path
        noopt = np.einsum(*path_test, optimize=False)
        opt = np.einsum(*path_test, optimize=exp_path)
        assert_almost_equal(noopt, opt)

    def test_path_type_input_internal_trace(self):
        # gh-20962: an explicit path whose first step is a single-operand
        # (internal trace) contraction.
        path_test = self.build_operands("cab,cdd->ab")
        exp_path = ["einsum_path", (1,), (0, 1)]

        path, path_str = np.einsum_path(*path_test, optimize=exp_path)
        self.assert_path_equal(path, exp_path)

        # Double check einsum works on the input path
        noopt = np.einsum(*path_test, optimize=False)
        opt = np.einsum(*path_test, optimize=exp_path)
        assert_almost_equal(noopt, opt)

    def test_path_type_input_invalid(self):
        # Explicit paths that do not cover the contraction must raise RuntimeError.
        path_test = self.build_operands("ab,bc,cd,de->ae")
        exp_path = ["einsum_path", (2, 3), (0, 1)]
        assert_raises(RuntimeError, np.einsum, *path_test, optimize=exp_path)
        assert_raises(RuntimeError, np.einsum_path, *path_test, optimize=exp_path)

        path_test = self.build_operands("a,a,a->a")
        exp_path = ["einsum_path", (1,), (0, 1)]
        assert_raises(RuntimeError, np.einsum, *path_test, optimize=exp_path)
        assert_raises(RuntimeError, np.einsum_path, *path_test, optimize=exp_path)

    def test_spaces(self):
        # gh-10794: subscripts with arbitrary spacing around the parts.
        arr = np.array([[1]])
        for sp in itertools.product(["", " "], repeat=4):
            # no error for any spacing
            np.einsum("{}...a{}->{}...a{}".format(*sp), arr)
1360
1361
1362class TestMisc(TestCase):
1363    def test_overlap(self):
1364        a = np.arange(9, dtype=int).reshape(3, 3)
1365        b = np.arange(9, dtype=int).reshape(3, 3)
1366        d = np.dot(a, b)
1367        # sanity check
1368        c = np.einsum("ij,jk->ik", a, b)
1369        assert_equal(c, d)
1370        # gh-10080, out overlaps one of the operands
1371        c = np.einsum("ij,jk->ik", a, b, out=b)
1372        assert_equal(c, d)
1373
1374
if __name__ == "__main__":
    # When this module is executed directly, hand control to PyTorch's shared
    # test runner (run_tests from torch.testing._internal.common_utils).
    run_tests()
1377