1# Owner(s): ["module: dynamo"] 2 3import functools 4import itertools 5from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest 6 7from pytest import raises as assert_raises 8 9import torch._numpy as np 10from torch._numpy.testing import ( 11 assert_, 12 assert_allclose, 13 assert_almost_equal, 14 assert_array_equal, 15 assert_equal, 16 suppress_warnings, 17) 18from torch.testing._internal.common_utils import ( 19 instantiate_parametrized_tests, 20 parametrize, 21 run_tests, 22 TestCase, 23) 24 25 26skip = functools.partial(skipif, True) 27 28 29# Setup for optimize einsum 30chars = "abcdefghij" 31sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3]) 32global_size_dict = dict(zip(chars, sizes)) 33 34 35@instantiate_parametrized_tests 36class TestEinsum(TestCase): 37 def test_einsum_errors(self): 38 for do_opt in [True, False]: 39 # Need enough arguments 40 assert_raises( 41 (TypeError, IndexError, ValueError), np.einsum, optimize=do_opt 42 ) 43 assert_raises((IndexError, ValueError), np.einsum, "", optimize=do_opt) 44 45 # subscripts must be a string 46 assert_raises((AttributeError, TypeError), np.einsum, 0, 0, optimize=do_opt) 47 48 # out parameter must be an array 49 assert_raises(TypeError, np.einsum, "", 0, out="test", optimize=do_opt) 50 51 # order parameter must be a valid order 52 assert_raises( 53 (NotImplementedError, ValueError), 54 np.einsum, 55 "", 56 0, 57 order="W", 58 optimize=do_opt, 59 ) 60 61 # casting parameter must be a valid casting 62 assert_raises(ValueError, np.einsum, "", 0, casting="blah", optimize=do_opt) 63 64 # dtype parameter must be a valid dtype 65 assert_raises( 66 TypeError, np.einsum, "", 0, dtype="bad_data_type", optimize=do_opt 67 ) 68 69 # other keyword arguments are rejected 70 assert_raises(TypeError, np.einsum, "", 0, bad_arg=0, optimize=do_opt) 71 72 # issue 4528 revealed a segfault with this call 73 assert_raises( 74 (RuntimeError, TypeError), np.einsum, *(None,) * 63, optimize=do_opt 75 ) 76 77 # number of operands must match count in subscripts string 78 assert_raises( 79 (RuntimeError, ValueError), np.einsum, "", 0, 0, optimize=do_opt 80 ) 81 assert_raises( 82 (RuntimeError, ValueError), np.einsum, ",", 0, [0], [0], optimize=do_opt 83 ) 84 assert_raises( 85 (RuntimeError, ValueError), np.einsum, ",", [0], optimize=do_opt 86 ) 87 88 # can't have more subscripts than dimensions in the operand 89 assert_raises( 90 (RuntimeError, ValueError), np.einsum, "i", 0, optimize=do_opt 91 ) 92 assert_raises( 93 (RuntimeError, ValueError), np.einsum, "ij", [0, 0], optimize=do_opt 94 ) 95 assert_raises( 96 (RuntimeError, ValueError), np.einsum, "...i", 0, optimize=do_opt 97 ) 98 assert_raises( 99 (RuntimeError, ValueError), np.einsum, "i...j", [0, 0], optimize=do_opt 100 ) 101 assert_raises( 102 (RuntimeError, ValueError), np.einsum, "i...", 0, optimize=do_opt 103 ) 104 assert_raises( 105 (RuntimeError, ValueError), np.einsum, "ij...", [0, 0], optimize=do_opt 106 ) 107 108 # invalid ellipsis 109 assert_raises( 110 (RuntimeError, ValueError), np.einsum, "i..", [0, 0], optimize=do_opt 111 ) 112 assert_raises( 113 (RuntimeError, ValueError), np.einsum, ".i...", [0, 0], optimize=do_opt 114 ) 115 assert_raises( 116 (RuntimeError, ValueError), np.einsum, "j->..j", [0, 0], optimize=do_opt 117 ) 118 assert_raises( 119 (RuntimeError, ValueError), 120 np.einsum, 121 "j->.j...", 122 [0, 0], 123 optimize=do_opt, 124 ) 125 126 # invalid subscript character 127 assert_raises( 128 (RuntimeError, ValueError), np.einsum, "i%...", [0, 0], optimize=do_opt 129 ) 130 assert_raises( 131 (RuntimeError, ValueError), np.einsum, "...j$", [0, 0], optimize=do_opt 132 ) 133 assert_raises( 134 (RuntimeError, ValueError), np.einsum, "i->&", [0, 0], optimize=do_opt 135 ) 136 137 # output subscripts must appear in input 138 assert_raises( 139 (RuntimeError, ValueError), np.einsum, "i->ij", [0, 0], optimize=do_opt 140 ) 141 142 # output subscripts may only be specified once 143 assert_raises( 144 (RuntimeError, ValueError), 145 np.einsum, 146 "ij->jij", 147 [[0, 0], [0, 0]], 148 optimize=do_opt, 149 ) 150 151 # dimensions much match when being collapsed 152 assert_raises( 153 (RuntimeError, ValueError), 154 np.einsum, 155 "ii", 156 np.arange(6).reshape(2, 3), 157 optimize=do_opt, 158 ) 159 assert_raises( 160 (RuntimeError, ValueError), 161 np.einsum, 162 "ii->i", 163 np.arange(6).reshape(2, 3), 164 optimize=do_opt, 165 ) 166 167 # broadcasting to new dimensions must be enabled explicitly 168 assert_raises( 169 (RuntimeError, ValueError), 170 np.einsum, 171 "i", 172 np.arange(6).reshape(2, 3), 173 optimize=do_opt, 174 ) 175 assert_raises( 176 (RuntimeError, ValueError), 177 np.einsum, 178 "i->i", 179 [[0, 1], [0, 1]], 180 out=np.arange(4).reshape(2, 2), 181 optimize=do_opt, 182 ) 183 with assert_raises((RuntimeError, ValueError)): # , match="'b'"): 184 # gh-11221 - 'c' erroneously appeared in the error message 185 a = np.ones((3, 3, 4, 5, 6)) 186 b = np.ones((3, 4, 5)) 187 np.einsum("aabcb,abc", a, b) 188 189 # Check order kwarg, asanyarray allows 1d to pass through 190 assert_raises( 191 (NotImplementedError, ValueError), 192 np.einsum, 193 "i->i", 194 np.arange(6).reshape(-1, 1), 195 optimize=do_opt, 196 order="d", 197 ) 198 199 @xfail # (reason="a view into smth else") 200 def test_einsum_views(self): 201 # pass-through 202 for do_opt in [True, False]: 203 a = np.arange(6) 204 a = a.reshape(2, 3) 205 206 b = np.einsum("...", a, optimize=do_opt) 207 assert_(b.tensor._base is a.tensor) 208 209 b = np.einsum(a, [Ellipsis], optimize=do_opt) 210 assert_(b.base is a) 211 212 b = np.einsum("ij", a, optimize=do_opt) 213 assert_(b.base is a) 214 assert_equal(b, a) 215 216 b = np.einsum(a, [0, 1], optimize=do_opt) 217 assert_(b.base is a) 218 assert_equal(b, a) 219 220 # output is writeable whenever input is writeable 221 b = np.einsum("...", a, optimize=do_opt) 222 assert_(b.flags["WRITEABLE"]) 223 a.flags["WRITEABLE"] = False 224 b = np.einsum("...", a, optimize=do_opt) 225 assert_(not b.flags["WRITEABLE"]) 226 227 # transpose 228 a = np.arange(6) 229 a.shape = (2, 3) 230 231 b = np.einsum("ji", a, optimize=do_opt) 232 assert_(b.base is a) 233 assert_equal(b, a.T) 234 235 b = np.einsum(a, [1, 0], optimize=do_opt) 236 assert_(b.base is a) 237 assert_equal(b, a.T) 238 239 # diagonal 240 a = np.arange(9) 241 a.shape = (3, 3) 242 243 b = np.einsum("ii->i", a, optimize=do_opt) 244 assert_(b.base is a) 245 assert_equal(b, [a[i, i] for i in range(3)]) 246 247 b = np.einsum(a, [0, 0], [0], optimize=do_opt) 248 assert_(b.base is a) 249 assert_equal(b, [a[i, i] for i in range(3)]) 250 251 # diagonal with various ways of broadcasting an additional dimension 252 a = np.arange(27) 253 a.shape = (3, 3, 3) 254 255 b = np.einsum("...ii->...i", a, optimize=do_opt) 256 assert_(b.base is a) 257 assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) 258 259 b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0], optimize=do_opt) 260 assert_(b.base is a) 261 assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) 262 263 b = np.einsum("ii...->...i", a, optimize=do_opt) 264 assert_(b.base is a) 265 assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(2, 0, 1)]) 266 267 b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0], optimize=do_opt) 268 assert_(b.base is a) 269 assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(2, 0, 1)]) 270 271 b = np.einsum("...ii->i...", a, optimize=do_opt) 272 assert_(b.base is a) 273 assert_equal(b, [a[:, i, i] for i in range(3)]) 274 275 b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis], optimize=do_opt) 276 assert_(b.base is a) 277 assert_equal(b, [a[:, i, i] for i in range(3)]) 278 279 b = np.einsum("jii->ij", a, optimize=do_opt) 280 assert_(b.base is a) 281 assert_equal(b, [a[:, i, i] for i in range(3)]) 282 283 b = np.einsum(a, [1, 0, 0], [0, 1], optimize=do_opt) 284 assert_(b.base is a) 285 assert_equal(b, [a[:, i, i] for i in range(3)]) 286 287 b = np.einsum("ii...->i...", a, optimize=do_opt) 288 assert_(b.base is a) 289 assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) 290 291 b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis], optimize=do_opt) 292 assert_(b.base is a) 293 assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) 294 295 b = np.einsum("i...i->i...", a, optimize=do_opt) 296 assert_(b.base is a) 297 assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) 298 299 b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis], optimize=do_opt) 300 assert_(b.base is a) 301 assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) 302 303 b = np.einsum("i...i->...i", a, optimize=do_opt) 304 assert_(b.base is a) 305 assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(1, 0, 2)]) 306 307 b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0], optimize=do_opt) 308 assert_(b.base is a) 309 assert_equal(b, [[x[i, i] for i in range(3)] for x in a.transpose(1, 0, 2)]) 310 311 # triple diagonal 312 a = np.arange(27) 313 a.shape = (3, 3, 3) 314 315 b = np.einsum("iii->i", a, optimize=do_opt) 316 assert_(b.base is a) 317 assert_equal(b, [a[i, i, i] for i in range(3)]) 318 319 b = np.einsum(a, [0, 0, 0], [0], optimize=do_opt) 320 assert_(b.base is a) 321 assert_equal(b, [a[i, i, i] for i in range(3)]) 322 323 # swap axes 324 a = np.arange(24) 325 a.shape = (2, 3, 4) 326 327 b = np.einsum("ijk->jik", a, optimize=do_opt) 328 assert_(b.base is a) 329 assert_equal(b, a.swapaxes(0, 1)) 330 331 b = np.einsum(a, [0, 1, 2], [1, 0, 2], optimize=do_opt) 332 assert_(b.base is a) 333 assert_equal(b, a.swapaxes(0, 1)) 334 335 # @np._no_nep50_warning() 336 def check_einsum_sums(self, dtype, do_opt=False): 337 dtype = np.dtype(dtype) 338 # Check various sums. Does many sizes to exercise unrolled loops. 339 340 # sum(a, axis=-1) 341 for n in range(1, 17): 342 a = np.arange(n, dtype=dtype) 343 assert_equal( 344 np.einsum("i->", a, optimize=do_opt), np.sum(a, axis=-1).astype(dtype) 345 ) 346 assert_equal( 347 np.einsum(a, [0], [], optimize=do_opt), np.sum(a, axis=-1).astype(dtype) 348 ) 349 350 for n in range(1, 17): 351 a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) 352 assert_equal( 353 np.einsum("...i->...", a, optimize=do_opt), 354 np.sum(a, axis=-1).astype(dtype), 355 ) 356 assert_equal( 357 np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt), 358 np.sum(a, axis=-1).astype(dtype), 359 ) 360 361 # sum(a, axis=0) 362 for n in range(1, 17): 363 a = np.arange(2 * n, dtype=dtype).reshape(2, n) 364 assert_equal( 365 np.einsum("i...->...", a, optimize=do_opt), 366 np.sum(a, axis=0).astype(dtype), 367 ) 368 assert_equal( 369 np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), 370 np.sum(a, axis=0).astype(dtype), 371 ) 372 373 for n in range(1, 17): 374 a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) 375 assert_equal( 376 np.einsum("i...->...", a, optimize=do_opt), 377 np.sum(a, axis=0).astype(dtype), 378 ) 379 assert_equal( 380 np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), 381 np.sum(a, axis=0).astype(dtype), 382 ) 383 384 # trace(a) 385 for n in range(1, 17): 386 a = np.arange(n * n, dtype=dtype).reshape(n, n) 387 assert_equal(np.einsum("ii", a, optimize=do_opt), np.trace(a).astype(dtype)) 388 assert_equal( 389 np.einsum(a, [0, 0], optimize=do_opt), # torch? 390 np.trace(a).astype(dtype), 391 ) 392 393 # gh-15961: should accept numpy int64 type in subscript list 394 # np_array = np.asarray([0, 0]) 395 # assert_equal(np.einsum(a, np_array, optimize=do_opt), 396 # np.trace(a).astype(dtype)) 397 # assert_equal(np.einsum(a, list(np_array), optimize=do_opt), 398 # np.trace(a).astype(dtype)) 399 400 # multiply(a, b) 401 assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case 402 for n in range(1, 17): 403 a = np.arange(3 * n, dtype=dtype).reshape(3, n) 404 b = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) 405 assert_equal( 406 np.einsum("..., ...", a, b, optimize=do_opt), np.multiply(a, b) 407 ) 408 assert_equal( 409 np.einsum(a, [Ellipsis], b, [Ellipsis], optimize=do_opt), 410 np.multiply(a, b), 411 ) 412 413 # inner(a,b) 414 for n in range(1, 17): 415 a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) 416 b = np.arange(n, dtype=dtype) 417 assert_equal(np.einsum("...i, ...i", a, b, optimize=do_opt), np.inner(a, b)) 418 assert_equal( 419 np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0], optimize=do_opt), 420 np.inner(a, b), 421 ) 422 423 for n in range(1, 11): 424 a = np.arange(n * 3 * 2, dtype=dtype).reshape(n, 3, 2) 425 b = np.arange(n, dtype=dtype) 426 assert_equal( 427 np.einsum("i..., i...", a, b, optimize=do_opt), np.inner(a.T, b.T).T 428 ) 429 assert_equal( 430 np.einsum(a, [0, Ellipsis], b, [0, Ellipsis], optimize=do_opt), 431 np.inner(a.T, b.T).T, 432 ) 433 434 # outer(a,b) 435 for n in range(1, 17): 436 a = np.arange(3, dtype=dtype) + 1 437 b = np.arange(n, dtype=dtype) + 1 438 assert_equal(np.einsum("i,j", a, b, optimize=do_opt), np.outer(a, b)) 439 assert_equal(np.einsum(a, [0], b, [1], optimize=do_opt), np.outer(a, b)) 440 441 # Suppress the complex warnings for the 'as f8' tests 442 with suppress_warnings() as sup: 443 # sup.filter(np.ComplexWarning) 444 445 # matvec(a,b) / a.dot(b) where a is matrix, b is vector 446 for n in range(1, 17): 447 a = np.arange(4 * n, dtype=dtype).reshape(4, n) 448 b = np.arange(n, dtype=dtype) 449 assert_equal(np.einsum("ij, j", a, b, optimize=do_opt), np.dot(a, b)) 450 assert_equal( 451 np.einsum(a, [0, 1], b, [1], optimize=do_opt), np.dot(a, b) 452 ) 453 454 c = np.arange(4, dtype=dtype) 455 np.einsum( 456 "ij,j", a, b, out=c, dtype="f8", casting="unsafe", optimize=do_opt 457 ) 458 assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype)) 459 c[...] = 0 460 np.einsum( 461 a, 462 [0, 1], 463 b, 464 [1], 465 out=c, 466 dtype="f8", 467 casting="unsafe", 468 optimize=do_opt, 469 ) 470 assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype)) 471 472 for n in range(1, 17): 473 a = np.arange(4 * n, dtype=dtype).reshape(4, n) 474 b = np.arange(n, dtype=dtype) 475 assert_equal( 476 np.einsum("ji,j", a.T, b.T, optimize=do_opt), np.dot(b.T, a.T) 477 ) 478 assert_equal( 479 np.einsum(a.T, [1, 0], b.T, [1], optimize=do_opt), np.dot(b.T, a.T) 480 ) 481 482 c = np.arange(4, dtype=dtype) 483 np.einsum( 484 "ji,j", 485 a.T, 486 b.T, 487 out=c, 488 dtype="f8", 489 casting="unsafe", 490 optimize=do_opt, 491 ) 492 assert_equal( 493 c, np.dot(b.T.astype("f8"), a.T.astype("f8")).astype(dtype) 494 ) 495 c[...] = 0 496 np.einsum( 497 a.T, 498 [1, 0], 499 b.T, 500 [1], 501 out=c, 502 dtype="f8", 503 casting="unsafe", 504 optimize=do_opt, 505 ) 506 assert_equal( 507 c, np.dot(b.T.astype("f8"), a.T.astype("f8")).astype(dtype) 508 ) 509 510 # matmat(a,b) / a.dot(b) where a is matrix, b is matrix 511 for n in range(1, 17): 512 if n < 8 or dtype != "f2": 513 a = np.arange(4 * n, dtype=dtype).reshape(4, n) 514 b = np.arange(n * 6, dtype=dtype).reshape(n, 6) 515 assert_equal( 516 np.einsum("ij,jk", a, b, optimize=do_opt), np.dot(a, b) 517 ) 518 assert_equal( 519 np.einsum(a, [0, 1], b, [1, 2], optimize=do_opt), np.dot(a, b) 520 ) 521 522 for n in range(1, 17): 523 a = np.arange(4 * n, dtype=dtype).reshape(4, n) 524 b = np.arange(n * 6, dtype=dtype).reshape(n, 6) 525 c = np.arange(24, dtype=dtype).reshape(4, 6) 526 np.einsum( 527 "ij,jk", a, b, out=c, dtype="f8", casting="unsafe", optimize=do_opt 528 ) 529 assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype)) 530 c[...] = 0 531 np.einsum( 532 a, 533 [0, 1], 534 b, 535 [1, 2], 536 out=c, 537 dtype="f8", 538 casting="unsafe", 539 optimize=do_opt, 540 ) 541 assert_equal(c, np.dot(a.astype("f8"), b.astype("f8")).astype(dtype)) 542 543 # matrix triple product (note this is not currently an efficient 544 # way to multiply 3 matrices) 545 a = np.arange(12, dtype=dtype).reshape(3, 4) 546 b = np.arange(20, dtype=dtype).reshape(4, 5) 547 c = np.arange(30, dtype=dtype).reshape(5, 6) 548 if dtype != "f2": 549 assert_equal( 550 np.einsum("ij,jk,kl", a, b, c, optimize=do_opt), a.dot(b).dot(c) 551 ) 552 assert_equal( 553 np.einsum(a, [0, 1], b, [1, 2], c, [2, 3], optimize=do_opt), 554 a.dot(b).dot(c), 555 ) 556 557 d = np.arange(18, dtype=dtype).reshape(3, 6) 558 np.einsum( 559 "ij,jk,kl", 560 a, 561 b, 562 c, 563 out=d, 564 dtype="f8", 565 casting="unsafe", 566 optimize=do_opt, 567 ) 568 tgt = a.astype("f8").dot(b.astype("f8")) 569 tgt = tgt.dot(c.astype("f8")).astype(dtype) 570 assert_equal(d, tgt) 571 572 d[...] = 0 573 np.einsum( 574 a, 575 [0, 1], 576 b, 577 [1, 2], 578 c, 579 [2, 3], 580 out=d, 581 dtype="f8", 582 casting="unsafe", 583 optimize=do_opt, 584 ) 585 tgt = a.astype("f8").dot(b.astype("f8")) 586 tgt = tgt.dot(c.astype("f8")).astype(dtype) 587 assert_equal(d, tgt) 588 589 # tensordot(a, b) 590 if np.dtype(dtype) != np.dtype("f2"): 591 a = np.arange(60, dtype=dtype).reshape(3, 4, 5) 592 b = np.arange(24, dtype=dtype).reshape(4, 3, 2) 593 assert_equal( 594 np.einsum("ijk, jil -> kl", a, b), 595 np.tensordot(a, b, axes=([1, 0], [0, 1])), 596 ) 597 assert_equal( 598 np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3]), 599 np.tensordot(a, b, axes=([1, 0], [0, 1])), 600 ) 601 602 c = np.arange(10, dtype=dtype).reshape(5, 2) 603 np.einsum( 604 "ijk,jil->kl", 605 a, 606 b, 607 out=c, 608 dtype="f8", 609 casting="unsafe", 610 optimize=do_opt, 611 ) 612 assert_equal( 613 c, 614 np.tensordot( 615 a.astype("f8"), b.astype("f8"), axes=([1, 0], [0, 1]) 616 ).astype(dtype), 617 ) 618 c[...] = 0 619 np.einsum( 620 a, 621 [0, 1, 2], 622 b, 623 [1, 0, 3], 624 [2, 3], 625 out=c, 626 dtype="f8", 627 casting="unsafe", 628 optimize=do_opt, 629 ) 630 assert_equal( 631 c, 632 np.tensordot( 633 a.astype("f8"), b.astype("f8"), axes=([1, 0], [0, 1]) 634 ).astype(dtype), 635 ) 636 637 # logical_and(logical_and(a!=0, b!=0), c!=0) 638 neg_val = -2 if dtype.kind != "u" else np.iinfo(dtype).max - 1 639 a = np.array([1, 3, neg_val, 0, 12, 13, 0, 1], dtype=dtype) 640 b = np.array([0, 3.5, 0.0, neg_val, 0, 1, 3, 12], dtype=dtype) 641 c = np.array([True, True, False, True, True, False, True, True]) 642 643 assert_equal( 644 np.einsum( 645 "i,i,i->i", a, b, c, dtype="?", casting="unsafe", optimize=do_opt 646 ), 647 np.logical_and(np.logical_and(a != 0, b != 0), c != 0), 648 ) 649 assert_equal( 650 np.einsum(a, [0], b, [0], c, [0], [0], dtype="?", casting="unsafe"), 651 np.logical_and(np.logical_and(a != 0, b != 0), c != 0), 652 ) 653 654 a = np.arange(9, dtype=dtype) 655 assert_equal(np.einsum(",i->", 3, a), 3 * np.sum(a)) 656 assert_equal(np.einsum(3, [], a, [0], []), 3 * np.sum(a)) 657 assert_equal(np.einsum("i,->", a, 3), 3 * np.sum(a)) 658 assert_equal(np.einsum(a, [0], 3, [], []), 3 * np.sum(a)) 659 660 # Various stride0, contiguous, and SSE aligned variants 661 for n in range(1, 25): 662 a = np.arange(n, dtype=dtype) 663 if np.dtype(dtype).itemsize > 1: 664 assert_equal( 665 np.einsum("...,...", a, a, optimize=do_opt), np.multiply(a, a) 666 ) 667 assert_equal(np.einsum("i,i", a, a, optimize=do_opt), np.dot(a, a)) 668 assert_equal(np.einsum("i,->i", a, 2, optimize=do_opt), 2 * a) 669 assert_equal(np.einsum(",i->i", 2, a, optimize=do_opt), 2 * a) 670 assert_equal(np.einsum("i,->", a, 2, optimize=do_opt), 2 * np.sum(a)) 671 assert_equal(np.einsum(",i->", 2, a, optimize=do_opt), 2 * np.sum(a)) 672 673 assert_equal( 674 np.einsum("...,...", a[1:], a[:-1], optimize=do_opt), 675 np.multiply(a[1:], a[:-1]), 676 ) 677 assert_equal( 678 np.einsum("i,i", a[1:], a[:-1], optimize=do_opt), 679 np.dot(a[1:], a[:-1]), 680 ) 681 assert_equal(np.einsum("i,->i", a[1:], 2, optimize=do_opt), 2 * a[1:]) 682 assert_equal(np.einsum(",i->i", 2, a[1:], optimize=do_opt), 2 * a[1:]) 683 assert_equal( 684 np.einsum("i,->", a[1:], 2, optimize=do_opt), 2 * np.sum(a[1:]) 685 ) 686 assert_equal( 687 np.einsum(",i->", 2, a[1:], optimize=do_opt), 2 * np.sum(a[1:]) 688 ) 689 690 # An object array, summed as the data type 691 # a = np.arange(9, dtype=object) 692 # 693 # b = np.einsum("i->", a, dtype=dtype, casting='unsafe') 694 # assert_equal(b, np.sum(a)) 695 # assert_equal(b.dtype, np.dtype(dtype)) 696 # 697 # b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') 698 # assert_equal(b, np.sum(a)) 699 # assert_equal(b.dtype, np.dtype(dtype)) 700 701 # A case which was failing (ticket #1885) 702 p = np.arange(2) + 1 703 q = np.arange(4).reshape(2, 2) + 3 704 r = np.arange(4).reshape(2, 2) + 7 705 assert_equal(np.einsum("z,mz,zm->", p, q, r), 253) 706 707 # singleton dimensions broadcast (gh-10343) 708 p = np.ones((10, 2)) 709 q = np.ones((1, 2)) 710 assert_array_equal( 711 np.einsum("ij,ij->j", p, q, optimize=True), 712 np.einsum("ij,ij->j", p, q, optimize=False), 713 ) 714 assert_array_equal(np.einsum("ij,ij->j", p, q, optimize=True), [10.0] * 2) 715 716 # a blas-compatible contraction broadcasting case which was failing 717 # for optimize=True (ticket #10930) 718 x = np.array([2.0, 3.0]) 719 y = np.array([4.0]) 720 assert_array_equal(np.einsum("i, i", x, y, optimize=False), 20.0) 721 assert_array_equal(np.einsum("i, i", x, y, optimize=True), 20.0) 722 723 # all-ones array was bypassing bug (ticket #10930) 724 p = np.ones((1, 5)) / 2 725 q = np.ones((5, 5)) / 2 726 for optimize in (True, False): 727 assert_array_equal( 728 np.einsum("...ij,...jk->...ik", p, p, optimize=optimize), 729 np.einsum("...ij,...jk->...ik", p, q, optimize=optimize), 730 ) 731 assert_array_equal( 732 np.einsum("...ij,...jk->...ik", p, q, optimize=optimize), 733 np.full((1, 5), 1.25), 734 ) 735 736 # Cases which were failing (gh-10899) 737 x = np.eye(2, dtype=dtype) 738 y = np.ones(2, dtype=dtype) 739 assert_array_equal( 740 np.einsum("ji,i->", x, y, optimize=optimize), [2.0] 741 ) # contig_contig_outstride0_two 742 assert_array_equal( 743 np.einsum("i,ij->", y, x, optimize=optimize), [2.0] 744 ) # stride0_contig_outstride0_two 745 assert_array_equal( 746 np.einsum("ij,i->", x, y, optimize=optimize), [2.0] 747 ) # contig_stride0_outstride0_two 748 749 @xfail # (reason="int overflow differs in numpy and pytorch") 750 def test_einsum_sums_int8(self): 751 self.check_einsum_sums("i1") 752 753 @xfail # (reason="int overflow differs in numpy and pytorch") 754 def test_einsum_sums_uint8(self): 755 self.check_einsum_sums("u1") 756 757 @xfail # (reason="int overflow differs in numpy and pytorch") 758 def test_einsum_sums_int16(self): 759 self.check_einsum_sums("i2") 760 761 def test_einsum_sums_int32(self): 762 self.check_einsum_sums("i4") 763 self.check_einsum_sums("i4", True) 764 765 def test_einsum_sums_int64(self): 766 self.check_einsum_sums("i8") 767 768 @xfail # (reason="np.float16(4641) == 4640.0") 769 def test_einsum_sums_float16(self): 770 self.check_einsum_sums("f2") 771 772 def test_einsum_sums_float32(self): 773 self.check_einsum_sums("f4") 774 775 def test_einsum_sums_float64(self): 776 self.check_einsum_sums("f8") 777 self.check_einsum_sums("f8", True) 778 779 def test_einsum_sums_cfloat64(self): 780 self.check_einsum_sums("c8") 781 self.check_einsum_sums("c8", True) 782 783 def test_einsum_sums_cfloat128(self): 784 self.check_einsum_sums("c16") 785 786 def test_einsum_misc(self): 787 # This call used to crash because of a bug in 788 # PyArray_AssignZero 789 a = np.ones((1, 2)) 790 b = np.ones((2, 2, 1)) 791 assert_equal(np.einsum("ij...,j...->i...", a, b), [[[2], [2]]]) 792 assert_equal(np.einsum("ij...,j...->i...", a, b, optimize=True), [[[2], [2]]]) 793 794 # Regression test for issue #10369 (test unicode inputs with Python 2) 795 assert_equal(np.einsum("ij...,j...->i...", a, b), [[[2], [2]]]) 796 assert_equal(np.einsum("...i,...i", [1, 2, 3], [2, 3, 4]), 20) 797 assert_equal( 798 np.einsum("...i,...i", [1, 2, 3], [2, 3, 4], optimize="greedy"), 20 799 ) 800 801 # The iterator had an issue with buffering this reduction 802 a = np.ones((5, 12, 4, 2, 3), np.int64) 803 b = np.ones((5, 12, 11), np.int64) 804 assert_equal( 805 np.einsum("ijklm,ijn,ijn->", a, b, b), np.einsum("ijklm,ijn->", a, b) 806 ) 807 assert_equal( 808 np.einsum("ijklm,ijn,ijn->", a, b, b, optimize=True), 809 np.einsum("ijklm,ijn->", a, b, optimize=True), 810 ) 811 812 # Issue #2027, was a problem in the contiguous 3-argument 813 # inner loop implementation 814 a = np.arange(1, 3) 815 b = np.arange(1, 5).reshape(2, 2) 816 c = np.arange(1, 9).reshape(4, 2) 817 assert_equal( 818 np.einsum("x,yx,zx->xzy", a, b, c), 819 [ 820 [[1, 3], [3, 9], [5, 15], [7, 21]], 821 [[8, 16], [16, 32], [24, 48], [32, 64]], 822 ], 823 ) 824 assert_equal( 825 np.einsum("x,yx,zx->xzy", a, b, c, optimize=True), 826 [ 827 [[1, 3], [3, 9], [5, 15], [7, 21]], 828 [[8, 16], [16, 32], [24, 48], [32, 64]], 829 ], 830 ) 831 832 # Ensure explicitly setting out=None does not cause an error 833 # see issue gh-15776 and issue gh-15256 834 assert_equal(np.einsum("i,j", [1], [2], out=None), [[2]]) 835 836 def test_subscript_range(self): 837 # Issue #7741, make sure that all letters of Latin alphabet (both uppercase & lowercase) can be used 838 # when creating a subscript from arrays 839 a = np.ones((2, 3)) 840 b = np.ones((3, 4)) 841 np.einsum(a, [0, 20], b, [20, 2], [0, 2], optimize=False) 842 np.einsum(a, [0, 27], b, [27, 2], [0, 2], optimize=False) 843 np.einsum(a, [0, 51], b, [51, 2], [0, 2], optimize=False) 844 assert_raises( 845 ValueError, 846 lambda: np.einsum(a, [0, 52], b, [52, 2], [0, 2], optimize=False), 847 ) 848 assert_raises( 849 ValueError, 850 lambda: np.einsum(a, [-1, 5], b, [5, 2], [-1, 2], optimize=False), 851 ) 852 853 def test_einsum_broadcast(self): 854 # Issue #2455 change in handling ellipsis 855 # remove the 'middle broadcast' error 856 # only use the 'RIGHT' iteration in prepare_op_axes 857 # adds auto broadcast on left where it belongs 858 # broadcast on right has to be explicit 859 # We need to test the optimized parsing as well 860 861 A = np.arange(2 * 3 * 4).reshape(2, 3, 4) 862 B = np.arange(3) 863 ref = np.einsum("ijk,j->ijk", A, B, optimize=False) 864 for opt in [True, False]: 865 assert_equal(np.einsum("ij...,j...->ij...", A, B, optimize=opt), ref) 866 assert_equal(np.einsum("ij...,...j->ij...", A, B, optimize=opt), ref) 867 assert_equal( 868 np.einsum("ij...,j->ij...", A, B, optimize=opt), ref 869 ) # used to raise error 870 871 A = np.arange(12).reshape((4, 3)) 872 B = np.arange(6).reshape((3, 2)) 873 ref = np.einsum("ik,kj->ij", A, B, optimize=False) 874 for opt in [True, False]: 875 assert_equal(np.einsum("ik...,k...->i...", A, B, optimize=opt), ref) 876 assert_equal(np.einsum("ik...,...kj->i...j", A, B, optimize=opt), ref) 877 assert_equal( 878 np.einsum("...k,kj", A, B, optimize=opt), ref 879 ) # used to raise error 880 assert_equal( 881 np.einsum("ik,k...->i...", A, B, optimize=opt), ref 882 ) # used to raise error 883 884 dims = [2, 3, 4, 5] 885 a = np.arange(np.prod(dims)).reshape(dims) 886 v = np.arange(dims[2]) 887 ref = np.einsum("ijkl,k->ijl", a, v, optimize=False) 888 for opt in [True, False]: 889 assert_equal(np.einsum("ijkl,k", a, v, optimize=opt), ref) 890 assert_equal( 891 np.einsum("...kl,k", a, v, optimize=opt), ref 892 ) # used to raise error 893 assert_equal(np.einsum("...kl,k...", a, v, optimize=opt), ref) 894 895 J, K, M = 160, 160, 120 896 A = np.arange(J * K * M).reshape(1, 1, 1, J, K, M) 897 B = np.arange(J * K * M * 3).reshape(J, K, M, 3) 898 ref = np.einsum("...lmn,...lmno->...o", A, B, optimize=False) 899 for opt in [True, False]: 900 assert_equal( 901 np.einsum("...lmn,lmno->...o", A, B, optimize=opt), ref 902 ) # used to raise error 903 904 def test_einsum_fixedstridebug(self): 905 # Issue #4485 obscure einsum bug 906 # This case revealed a bug in nditer where it reported a stride 907 # as 'fixed' (0) when it was in fact not fixed during processing 908 # (0 or 4). The reason for the bug was that the check for a fixed 909 # stride was using the information from the 2D inner loop reuse 910 # to restrict the iteration dimensions it had to validate to be 911 # the same, but that 2D inner loop reuse logic is only triggered 912 # during the buffer copying step, and hence it was invalid to 913 # rely on those values. The fix is to check all the dimensions 914 # of the stride in question, which in the test case reveals that 915 # the stride is not fixed. 916 # 917 # NOTE: This test is triggered by the fact that the default buffersize, 918 # used by einsum, is 8192, and 3*2731 = 8193, is larger than that 919 # and results in a mismatch between the buffering and the 920 # striding for operand A. 921 A = np.arange(2 * 3).reshape(2, 3).astype(np.float32) 922 B = np.arange(2 * 3 * 2731).reshape(2, 3, 2731).astype(np.int16) 923 es = np.einsum("cl, cpx->lpx", A, B) 924 tp = np.tensordot(A, B, axes=(0, 0)) 925 assert_equal(es, tp) 926 # The following is the original test case from the bug report, 927 # made repeatable by changing random arrays to aranges. 928 A = np.arange(3 * 3).reshape(3, 3).astype(np.float64) 929 B = np.arange(3 * 3 * 64 * 64).reshape(3, 3, 64, 64).astype(np.float32) 930 es = np.einsum("cl, cpxy->lpxy", A, B) 931 tp = np.tensordot(A, B, axes=(0, 0)) 932 assert_equal(es, tp) 933 934 def test_einsum_fixed_collapsingbug(self): 935 # Issue #5147. 936 # The bug only occurred when output argument of einssum was used. 937 x = np.random.normal(0, 1, (5, 5, 5, 5)) 938 y1 = np.zeros((5, 5)) 939 np.einsum("aabb->ab", x, out=y1) 940 idx = np.arange(5) 941 y2 = x[idx[:, None], idx[:, None], idx, idx] 942 assert_equal(y1, y2) 943 944 def test_einsum_failed_on_p9_and_s390x(self): 945 # Issues gh-14692 and gh-12689 946 # Bug with signed vs unsigned char errored on power9 and s390x Linux 947 tensor = np.random.random_sample((10, 10, 10, 10)) 948 x = np.einsum("ijij->", tensor) 949 y = tensor.trace(axis1=0, axis2=2).trace() 950 assert_allclose(x, y) 951 952 @xfail # (reason="no base") 953 def test_einsum_all_contig_non_contig_output(self): 954 # Issue gh-5907, tests that the all contiguous special case 955 # actually checks the contiguity of the output 956 x = np.ones((5, 5)) 957 out = np.ones(10)[::2] 958 correct_base = np.ones(10) 959 correct_base[::2] = 5 960 # Always worked (inner iteration is done with 0-stride): 961 np.einsum("mi,mi,mi->m", x, x, x, out=out) 962 assert_array_equal(out.base, correct_base) 963 # Example 1: 964 out = np.ones(10)[::2] 965 np.einsum("im,im,im->m", x, x, x, out=out) 966 assert_array_equal(out.base, correct_base) 967 # Example 2, buffering causes x to be contiguous but 968 # special cases do not catch the operation before: 969 out = np.ones((2, 2, 2))[..., 0] 970 correct_base = np.ones((2, 2, 2)) 971 correct_base[..., 0] = 2 972 x = np.ones((2, 2), np.float32) 973 np.einsum("ij,jk->ik", x, x, out=out) 974 assert_array_equal(out.base, correct_base) 975 976 @parametrize("dtype", np.typecodes["AllFloat"] + np.typecodes["AllInteger"]) 977 def test_different_paths(self, dtype): 978 # Test originally added to cover broken float16 path: gh-20305 979 # Likely most are covered elsewhere, at least partially. 980 dtype = np.dtype(dtype) 981 # Simple test, designed to excersize most specialized code paths, 982 # note the +0.5 for floats. This makes sure we use a float value 983 # where the results must be exact. 984 arr = (np.arange(7) + 0.5).astype(dtype) 985 scalar = np.array(2, dtype=dtype) 986 987 # contig -> scalar: 988 res = np.einsum("i->", arr) 989 assert res == arr.sum() 990 # contig, contig -> contig: 991 res = np.einsum("i,i->i", arr, arr) 992 assert_array_equal(res, arr * arr) 993 # noncontig, noncontig -> contig: 994 res = np.einsum("i,i->i", arr.repeat(2)[::2], arr.repeat(2)[::2]) 995 assert_array_equal(res, arr * arr) 996 # contig + contig -> scalar 997 assert np.einsum("i,i->", arr, arr) == (arr * arr).sum() 998 # contig + scalar -> contig (with out) 999 out = np.ones(7, dtype=dtype) 1000 res = np.einsum("i,->i", arr, dtype.type(2), out=out) 1001 assert_array_equal(res, arr * dtype.type(2)) 1002 # scalar + contig -> contig (with out) 1003 res = np.einsum(",i->i", scalar, arr) 1004 assert_array_equal(res, arr * dtype.type(2)) 1005 # scalar + contig -> scalar 1006 res = np.einsum(",i->", scalar, arr) 1007 # Use einsum to compare to not have difference due to sum round-offs: 1008 assert res == np.einsum("i->", scalar * arr) 1009 # contig + scalar -> scalar 1010 res = np.einsum("i,->", arr, scalar) 1011 # Use einsum to compare to not have difference due to sum round-offs: 1012 assert res == np.einsum("i->", scalar * arr) 1013 # contig + contig + contig -> scalar 1014 1015 if dtype in ["e", "B", "b"]: 1016 # FIXME make xfail 1017 raise SkipTest("overflow differs in pytorch and numpy") 1018 1019 arr = np.array([0.5, 0.5, 0.25, 4.5, 3.0], dtype=dtype) 1020 res = np.einsum("i,i,i->", arr, arr, arr) 1021 assert_array_equal(res, (arr * arr * arr).sum()) 1022 # four arrays: 1023 res = np.einsum("i,i,i,i->", arr, arr, arr, arr) 1024 assert_array_equal(res, (arr * arr * arr * arr).sum()) 1025 1026 def test_small_boolean_arrays(self): 1027 # See gh-5946. 1028 # Use array of True embedded in False. 1029 a = np.zeros((16, 1, 1), dtype=np.bool_)[:2] 1030 a[...] = True 1031 out = np.zeros((16, 1, 1), dtype=np.bool_)[:2] 1032 tgt = np.ones((2, 1, 1), dtype=np.bool_) 1033 res = np.einsum("...ij,...jk->...ik", a, a, out=out) 1034 assert_equal(res, tgt) 1035 1036 def test_out_is_res(self): 1037 a = np.arange(9).reshape(3, 3) 1038 res = np.einsum("...ij,...jk->...ik", a, a, out=a) 1039 assert res is a 1040 1041 def optimize_compare(self, subscripts, operands=None): 1042 # Tests all paths of the optimization function against 1043 # conventional einsum 1044 if operands is None: 1045 args = [subscripts] 1046 terms = subscripts.split("->")[0].split(",") 1047 for term in terms: 1048 dims = [global_size_dict[x] for x in term] 1049 args.append(np.random.rand(*dims)) 1050 else: 1051 args = [subscripts] + operands 1052 1053 noopt = np.einsum(*args, optimize=False) 1054 opt = np.einsum(*args, optimize="greedy") 1055 assert_almost_equal(opt, noopt) 1056 opt = np.einsum(*args, optimize="optimal") 1057 assert_almost_equal(opt, noopt) 1058 1059 def test_hadamard_like_products(self): 1060 # Hadamard outer products 1061 self.optimize_compare("a,ab,abc->abc") 1062 self.optimize_compare("a,b,ab->ab") 1063 1064 def test_index_transformations(self): 1065 # Simple index transformation cases 1066 self.optimize_compare("ea,fb,gc,hd,abcd->efgh") 1067 self.optimize_compare("ea,fb,abcd,gc,hd->efgh") 1068 self.optimize_compare("abcd,ea,fb,gc,hd->efgh") 1069 1070 def test_complex(self): 1071 # Long test cases 1072 self.optimize_compare("acdf,jbje,gihb,hfac,gfac,gifabc,hfac") 1073 self.optimize_compare("acdf,jbje,gihb,hfac,gfac,gifabc,hfac") 1074 self.optimize_compare("cd,bdhe,aidb,hgca,gc,hgibcd,hgac") 1075 self.optimize_compare("abhe,hidj,jgba,hiab,gab") 1076 self.optimize_compare("bde,cdh,agdb,hica,ibd,hgicd,hiac") 1077 self.optimize_compare("chd,bde,agbc,hiad,hgc,hgi,hiad") 1078 self.optimize_compare("chd,bde,agbc,hiad,bdi,cgh,agdb") 1079 self.optimize_compare("bdhe,acad,hiab,agac,hibd") 1080 1081 def test_collapse(self): 1082 # Inner products 1083 self.optimize_compare("ab,ab,c->") 1084 self.optimize_compare("ab,ab,c->c") 1085 self.optimize_compare("ab,ab,cd,cd->") 1086 self.optimize_compare("ab,ab,cd,cd->ac") 1087 self.optimize_compare("ab,ab,cd,cd->cd") 1088 self.optimize_compare("ab,ab,cd,cd,ef,ef->") 1089 1090 def test_expand(self): 1091 # Outer products 1092 self.optimize_compare("ab,cd,ef->abcdef") 1093 self.optimize_compare("ab,cd,ef->acdf") 1094 self.optimize_compare("ab,cd,de->abcde") 1095 self.optimize_compare("ab,cd,de->be") 1096 self.optimize_compare("ab,bcd,cd->abcd") 1097 self.optimize_compare("ab,bcd,cd->abd") 1098 1099 def test_edge_cases(self): 1100 # Difficult edge cases for optimization 1101 self.optimize_compare("eb,cb,fb->cef") 1102 self.optimize_compare("dd,fb,be,cdb->cef") 1103 self.optimize_compare("bca,cdb,dbf,afc->") 1104 self.optimize_compare("dcc,fce,ea,dbf->ab") 1105 self.optimize_compare("fdf,cdd,ccd,afe->ae") 1106 self.optimize_compare("abcd,ad") 1107 self.optimize_compare("ed,fcd,ff,bcf->be") 1108 self.optimize_compare("baa,dcf,af,cde->be") 1109 self.optimize_compare("bd,db,eac->ace") 1110 self.optimize_compare("fff,fae,bef,def->abd") 1111 self.optimize_compare("efc,dbc,acf,fd->abe") 1112 self.optimize_compare("ba,ac,da->bcd") 1113 1114 def test_inner_product(self): 1115 # Inner products 1116 self.optimize_compare("ab,ab") 1117 self.optimize_compare("ab,ba") 1118 self.optimize_compare("abc,abc") 1119 self.optimize_compare("abc,bac") 1120 self.optimize_compare("abc,cba") 1121 1122 def test_random_cases(self): 1123 # Randomly built test cases 1124 self.optimize_compare("aab,fa,df,ecc->bde") 1125 self.optimize_compare("ecb,fef,bad,ed->ac") 1126 self.optimize_compare("bcf,bbb,fbf,fc->") 1127 self.optimize_compare("bb,ff,be->e") 1128 self.optimize_compare("bcb,bb,fc,fff->") 1129 self.optimize_compare("fbb,dfd,fc,fc->") 1130 self.optimize_compare("afd,ba,cc,dc->bf") 1131 self.optimize_compare("adb,bc,fa,cfc->d") 1132 self.optimize_compare("bbd,bda,fc,db->acf") 1133 self.optimize_compare("dba,ead,cad->bce") 1134 self.optimize_compare("aef,fbc,dca->bde") 1135 1136 def test_combined_views_mapping(self): 1137 # gh-10792 1138 a = np.arange(9).reshape(1, 1, 3, 1, 3) 1139 b = np.einsum("bbcdc->d", a) 1140 assert_equal(b, [12]) 1141 1142 def test_broadcasting_dot_cases(self): 1143 # Ensures broadcasting cases are not mistaken for GEMM 1144 1145 a = np.random.rand(1, 5, 4) 1146 b = np.random.rand(4, 6) 1147 c = np.random.rand(5, 6) 1148 d = np.random.rand(10) 1149 1150 self.optimize_compare("ijk,kl,jl", operands=[a, b, c]) 1151 self.optimize_compare("ijk,kl,jl,i->i", operands=[a, b, c, d]) 1152 1153 e = np.random.rand(1, 1, 5, 4) 1154 f = np.random.rand(7, 7) 1155 self.optimize_compare("abjk,kl,jl", operands=[e, b, c]) 1156 self.optimize_compare("abjk,kl,jl,ab->ab", operands=[e, b, c, f]) 1157 1158 # Edge case found in gh-11308 1159 g = np.arange(64).reshape(2, 4, 8) 1160 self.optimize_compare("obk,ijk->ioj", operands=[g, g]) 1161 1162 @xfail # (reason="order='F' not supported") 1163 def test_output_order(self): 1164 # Ensure output order is respected for optimize cases, the below 1165 # conraction should yield a reshaped tensor view 1166 # gh-16415 1167 1168 a = np.ones((2, 3, 5), order="F") 1169 b = np.ones((4, 3), order="F") 1170 1171 for opt in [True, False]: 1172 tmp = np.einsum("...ft,mf->...mt", a, b, order="a", optimize=opt) 1173 assert_(tmp.flags.f_contiguous) 1174 1175 tmp = np.einsum("...ft,mf->...mt", a, b, order="f", optimize=opt) 1176 assert_(tmp.flags.f_contiguous) 1177 1178 tmp = np.einsum("...ft,mf->...mt", a, b, order="c", optimize=opt) 1179 assert_(tmp.flags.c_contiguous) 1180 1181 tmp = np.einsum("...ft,mf->...mt", a, b, order="k", optimize=opt) 1182 assert_(tmp.flags.c_contiguous is False) 1183 assert_(tmp.flags.f_contiguous is False) 1184 1185 tmp = np.einsum("...ft,mf->...mt", a, b, optimize=opt) 1186 assert_(tmp.flags.c_contiguous is False) 1187 assert_(tmp.flags.f_contiguous is False) 1188 1189 c = np.ones((4, 3), order="C") 1190 for opt in [True, False]: 1191 tmp = np.einsum("...ft,mf->...mt", a, c, order="a", optimize=opt) 1192 assert_(tmp.flags.c_contiguous) 1193 1194 d = np.ones((2, 3, 5), order="C") 1195 for opt in [True, False]: 1196 tmp = np.einsum("...ft,mf->...mt", d, c, order="a", optimize=opt) 1197 assert_(tmp.flags.c_contiguous) 1198 1199 1200@skip(reason="no pytorch analog") 1201class TestEinsumPath(TestCase): 1202 def build_operands(self, string, size_dict=global_size_dict): 1203 # Builds views based off initial operands 1204 operands = [string] 1205 terms = string.split("->")[0].split(",") 1206 for term in terms: 1207 dims = [size_dict[x] for x in term] 1208 operands.append(np.random.rand(*dims)) 1209 1210 return operands 1211 1212 def assert_path_equal(self, comp, benchmark): 1213 # Checks if list of tuples are equivalent 1214 ret = len(comp) == len(benchmark) 1215 assert_(ret) 1216 for pos in range(len(comp) - 1): 1217 ret &= isinstance(comp[pos + 1], tuple) 1218 ret &= comp[pos + 1] == benchmark[pos + 1] 1219 assert_(ret) 1220 1221 def test_memory_contraints(self): 1222 # Ensure memory constraints are satisfied 1223 1224 outer_test = self.build_operands("a,b,c->abc") 1225 1226 path, path_str = np.einsum_path(*outer_test, optimize=("greedy", 0)) 1227 self.assert_path_equal(path, ["einsum_path", (0, 1, 2)]) 1228 1229 path, path_str = np.einsum_path(*outer_test, optimize=("optimal", 0)) 1230 self.assert_path_equal(path, ["einsum_path", (0, 1, 2)]) 1231 1232 long_test = self.build_operands("acdf,jbje,gihb,hfac") 1233 path, path_str = np.einsum_path(*long_test, optimize=("greedy", 0)) 1234 self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)]) 1235 1236 path, path_str = np.einsum_path(*long_test, optimize=("optimal", 0)) 1237 self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)]) 1238 1239 def test_long_paths(self): 1240 # Long complex cases 1241 1242 # Long test 1 1243 long_test1 = self.build_operands("acdf,jbje,gihb,hfac,gfac,gifabc,hfac") 1244 path, path_str = np.einsum_path(*long_test1, optimize="greedy") 1245 self.assert_path_equal( 1246 path, ["einsum_path", (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)] 1247 ) 1248 1249 path, path_str = np.einsum_path(*long_test1, optimize="optimal") 1250 self.assert_path_equal( 1251 path, ["einsum_path", (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)] 1252 ) 1253 1254 # Long test 2 1255 long_test2 = self.build_operands("chd,bde,agbc,hiad,bdi,cgh,agdb") 1256 path, path_str = np.einsum_path(*long_test2, optimize="greedy") 1257 self.assert_path_equal( 1258 path, ["einsum_path", (3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)] 1259 ) 1260 1261 path, path_str = np.einsum_path(*long_test2, optimize="optimal") 1262 self.assert_path_equal( 1263 path, ["einsum_path", (0, 5), (1, 4), (3, 4), (1, 3), (1, 2), (0, 1)] 1264 ) 1265 1266 def test_edge_paths(self): 1267 # Difficult edge cases 1268 1269 # Edge test1 1270 edge_test1 = self.build_operands("eb,cb,fb->cef") 1271 path, path_str = np.einsum_path(*edge_test1, optimize="greedy") 1272 self.assert_path_equal(path, ["einsum_path", (0, 2), (0, 1)]) 1273 1274 path, path_str = np.einsum_path(*edge_test1, optimize="optimal") 1275 self.assert_path_equal(path, ["einsum_path", (0, 2), (0, 1)]) 1276 1277 # Edge test2 1278 edge_test2 = self.build_operands("dd,fb,be,cdb->cef") 1279 path, path_str = np.einsum_path(*edge_test2, optimize="greedy") 1280 self.assert_path_equal(path, ["einsum_path", (0, 3), (0, 1), (0, 1)]) 1281 1282 path, path_str = np.einsum_path(*edge_test2, optimize="optimal") 1283 self.assert_path_equal(path, ["einsum_path", (0, 3), (0, 1), (0, 1)]) 1284 1285 # Edge test3 1286 edge_test3 = self.build_operands("bca,cdb,dbf,afc->") 1287 path, path_str = np.einsum_path(*edge_test3, optimize="greedy") 1288 self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)]) 1289 1290 path, path_str = np.einsum_path(*edge_test3, optimize="optimal") 1291 self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)]) 1292 1293 # Edge test4 1294 edge_test4 = self.build_operands("dcc,fce,ea,dbf->ab") 1295 path, path_str = np.einsum_path(*edge_test4, optimize="greedy") 1296 self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 1), (0, 1)]) 1297 1298 path, path_str = np.einsum_path(*edge_test4, optimize="optimal") 1299 self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 2), (0, 1)]) 1300 1301 # Edge test5 1302 edge_test4 = self.build_operands( 1303 "a,ac,ab,ad,cd,bd,bc->", size_dict={"a": 20, "b": 20, "c": 20, "d": 20} 1304 ) 1305 path, path_str = np.einsum_path(*edge_test4, optimize="greedy") 1306 self.assert_path_equal(path, ["einsum_path", (0, 1), (0, 1, 2, 3, 4, 5)]) 1307 1308 path, path_str = np.einsum_path(*edge_test4, optimize="optimal") 1309 self.assert_path_equal(path, ["einsum_path", (0, 1), (0, 1, 2, 3, 4, 5)]) 1310 1311 def test_path_type_input(self): 1312 # Test explicit path handling 1313 path_test = self.build_operands("dcc,fce,ea,dbf->ab") 1314 1315 path, path_str = np.einsum_path(*path_test, optimize=False) 1316 self.assert_path_equal(path, ["einsum_path", (0, 1, 2, 3)]) 1317 1318 path, path_str = np.einsum_path(*path_test, optimize=True) 1319 self.assert_path_equal(path, ["einsum_path", (1, 2), (0, 1), (0, 1)]) 1320 1321 exp_path = ["einsum_path", (0, 2), (0, 2), (0, 1)] 1322 path, path_str = np.einsum_path(*path_test, optimize=exp_path) 1323 self.assert_path_equal(path, exp_path) 1324 1325 # Double check einsum works on the input path 1326 noopt = np.einsum(*path_test, optimize=False) 1327 opt = np.einsum(*path_test, optimize=exp_path) 1328 assert_almost_equal(noopt, opt) 1329 1330 def test_path_type_input_internal_trace(self): 1331 # gh-20962 1332 path_test = self.build_operands("cab,cdd->ab") 1333 exp_path = ["einsum_path", (1,), (0, 1)] 1334 1335 path, path_str = np.einsum_path(*path_test, optimize=exp_path) 1336 self.assert_path_equal(path, exp_path) 1337 1338 # Double check einsum works on the input path 1339 noopt = np.einsum(*path_test, optimize=False) 1340 opt = np.einsum(*path_test, optimize=exp_path) 1341 assert_almost_equal(noopt, opt) 1342 1343 def test_path_type_input_invalid(self): 1344 path_test = self.build_operands("ab,bc,cd,de->ae") 1345 exp_path = ["einsum_path", (2, 3), (0, 1)] 1346 assert_raises(RuntimeError, np.einsum, *path_test, optimize=exp_path) 1347 assert_raises(RuntimeError, np.einsum_path, *path_test, optimize=exp_path) 1348 1349 path_test = self.build_operands("a,a,a->a") 1350 exp_path = ["einsum_path", (1,), (0, 1)] 1351 assert_raises(RuntimeError, np.einsum, *path_test, optimize=exp_path) 1352 assert_raises(RuntimeError, np.einsum_path, *path_test, optimize=exp_path) 1353 1354 def test_spaces(self): 1355 # gh-10794 1356 arr = np.array([[1]]) 1357 for sp in itertools.product(["", " "], repeat=4): 1358 # no error for any spacing 1359 np.einsum("{}...a{}->{}...a{}".format(*sp), arr) 1360 1361 1362class TestMisc(TestCase): 1363 def test_overlap(self): 1364 a = np.arange(9, dtype=int).reshape(3, 3) 1365 b = np.arange(9, dtype=int).reshape(3, 3) 1366 d = np.dot(a, b) 1367 # sanity check 1368 c = np.einsum("ij,jk->ik", a, b) 1369 assert_equal(c, d) 1370 # gh-10080, out overlaps one of the operands 1371 c = np.einsum("ij,jk->ik", a, b, out=b) 1372 assert_equal(c, d) 1373 1374 1375if __name__ == "__main__": 1376 run_tests() 1377