# Owner(s): ["module: inductor"]

import sys
import unittest

import torch
import torch._inductor
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise

inplace_bin_ops_under_test = [
    torch._foreach_add_,
    torch._foreach_mul_,
    torch._foreach_sub_,
    torch._foreach_div_,
]

bin_ops_under_test = [
    torch._foreach_add,
    torch._foreach_mul,
    torch._foreach_sub,
    torch._foreach_div,
    torch._foreach_maximum,
    torch._foreach_minimum,
    torch._foreach_clamp_max,
    torch._foreach_clamp_min,
    aten._foreach_copy,
]

un_ops_under_test = [
    torch._foreach_reciprocal,
    torch._foreach_neg,
    torch._foreach_sign,
    torch._foreach_abs,
    torch._foreach_sqrt,
]
compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]
all_ops = parametrize(
    "op", bin_ops_under_test + un_ops_under_test, name_fn=lambda f: f.__name__
)
bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
inplace_bin_ops = parametrize(
    "op", inplace_bin_ops_under_test, name_fn=lambda f: f.__name__
)
scalar_bin_ops = parametrize("op", bin_ops_under_test[:4], name_fn=lambda f: f.__name__)
scalar_tensor_bin_ops = parametrize(
    "op", bin_ops_under_test[:2], name_fn=lambda f: f.__name__
)
decomp_ops = parametrize("op", compose_ops, name_fn=lambda f: f.__name__)


def gen_args(op):
    if op in un_ops_under_test:
        return (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )
    else:
        return (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )


@instantiate_parametrized_tests
class ForeachTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    def _test_single_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                return op([a0, a1])

        else:

            def fn(a0, a1, b0, b1):
                return op([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

    def _test_single_scalar(self, op):
        def fn(a0, a1):
            return op([a0, a1], 3.3)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

    def _test_single_scalar_tensor(self, op):
        def fn(a0, a1):
            return op([a0, a1], torch.tensor(3.3, device="cuda:0"))

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

    # called in test_cuda_cpp_wrapper.py
    @requires_cuda
    def test_foreach_cpp_wrapper_cuda(self):
        self._test_single_list(op=torch._foreach_add)
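
    # Note: each parametrized test below compiles a foreach op over a list of
    # tensors and asserts on torch._inductor.metrics.generated_kernel_count.
    # The lists are expected to be horizontally fused into a single combo
    # kernel, so the expected count is 1 unless a test states otherwise.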
    @requires_cuda
    @all_ops
    def test_single_list(self, op):
        self._test_single_list(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_single_scalar(self, op):
        self._test_single_scalar(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_tensor_bin_ops
    def test_single_scalar_tensor(self, op):
        self._test_single_scalar_tensor(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @all_ops
    def test_scheduler_fusion_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch._foreach_sqrt(c)

        else:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_scheduler_fusion_scalar(self, op):
        def fn(a0, a1):
            c = op([a0, a1], 3.4)
            return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_broadcasting(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        fn_opt = torch._dynamo.optimize()(fn)

        inputs = (
            torch.rand(10, 1, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(1, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )
        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @all_ops
    def test_singleton_lists(self, op):
        if op in un_ops_under_test:

            def fn(a0):
                return op([a0])

            args = (torch.rand(10, 10, device="cuda:0"),)
        else:

            def fn(a0, b0):
                return op([a0], [b0])

            args = (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
            )

        self.check_model_cuda(
            fn,
            args,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @bin_ops
    def test_type_promotion(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        fn_opt = torch._dynamo.optimize()(fn)

        max32 = torch.iinfo(torch.int32).max
        max64 = torch.iinfo(torch.int64).max
        inputs = (
            torch.randint(max32, (10, 10), device="cuda:0", dtype=torch.int32),
            torch.randint(max32, (20, 20), device="cuda:0", dtype=torch.int32),
            torch.randint(max32, (10, 10), device="cuda:0", dtype=torch.int32),
            torch.randint(max64, (20, 20), device="cuda:0", dtype=torch.int64),
        )
        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
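
    # The two kernel-split tests below build lists long enough to overflow the
    # kernel argument limit (max_args = 370, matching the limit Inductor used
    # when these tests were written). Roughly, a binary foreach op needs ~3
    # kernel arguments per list element (two inputs, one output) and a scalar
    # variant ~2, so max_args // 3 + 1 (resp. // 2 + 1) elements should force
    # Inductor to split the work into two generated kernels instead of one.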
    @requires_cuda
    @scalar_bin_ops
    def test_kernel_split_arg_limit_list(self, op):
        # NB: foreach_copy won't pass this test because it will DCE one set of buffers

        def fn(a, b):
            return op(a, b)

        fn_opt = torch._dynamo.optimize()(fn)

        max_args = 370
        max_list_len = (max_args // 3) + 1
        inputs = (
            [torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],
            [torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],
        )

        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    @scalar_bin_ops
    @unittest.skip(
        "Triton recursion depth exceeded: https://github.com/openai/triton/issues/1763"
    )
    def test_kernel_split_arg_limit_scalar(self, op):
        def fn(a):
            return op(a, 3.3)

        fn_opt = torch._dynamo.optimize()(fn)

        max_args = 370
        max_list_len = (max_args // 2) + 1
        inputs = ([torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],)

        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    @bin_ops
    def test_fusion_duplicate_buffer_list(self, op):
        def fn(a0, a1, b0, b1):
            c = op([a0, a1], [b0, b1])
            return op([a0, b0], [c[0], c[0]])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
            reference_in_float=False,
            check_lowp=False,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @all_ops
    def test_non_foreach_consumer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch.mul(c[0], a0)

        else:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return torch.mul(c[0], a0)

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_consumer_scalar(self, op):
        def fn(a0, a1):
            c = op([a0, a1], 4.7)
            return torch.mul(c[0], a0)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @all_ops
    def test_non_foreach_producer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.add(a1, a1)
                return op([c0, c1])

        else:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                return op([a0, a1], [c0, c1])

        self.check_model_cuda(
            fn, gen_args(op), reference_in_float=False, check_lowp=False
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.mul(a0, b0)
            c1 = torch.mul(a1, b1)
            return op([c0, c1], 5.6)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
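
    # The consumer/producer tests above exercise one direction at a time; the
    # next two combine both: plain pointwise producers feed the foreach op and
    # plain pointwise consumers read its outputs, and the whole graph is still
    # expected to fuse into a single generated kernel.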
    @requires_cuda
    @all_ops
    def test_non_foreach_consumer_producer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.mul(a1, a1)
                d = op([c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        else:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                d = op([a0, a1], [c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        self.check_model_cuda(
            fn,
            gen_args(op),
            reference_in_float=False,
            check_lowp=False,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_consumer_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            d = op([c0, c1], 5.8)
            e0 = torch.mul(d[0], a0)
            e1 = torch.mul(d[1], a1)
            return [e0, e1]

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
            reference_in_float=False,
            check_lowp=False,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @bin_ops
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_dynamic_shapes_fallback(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    def test_enable_dynamic_shapes_python_wrapper(self, op=torch._foreach_add):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_enable_dynamic_shapes_cpp_wrapper_cuda(self, op=torch._foreach_add):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

    @unittest.skipIf(IS_FBCODE, "cpp compile not supported in fbcode")
    @bin_ops
    def test_cpu_cpp_fallback(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cpu"),
            torch.rand(20, 20, device="cpu"),
            torch.rand(10, 10, device="cpu"),
            torch.rand(20, 20, device="cpu"),
        )

        self.check_model_cpu(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
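
    # _foreach_addcdiv / _foreach_addcmul (compose_ops above) lower through a
    # decomposition into simpler pointwise foreach ops; test_decomp checks
    # that the decomposed graph still fuses into one combo kernel.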
    @requires_cuda
    @decomp_ops
    def test_decomp(self, op):
        def fn(a0, a1, b0, b1, c0, c1):
            return op([a0, a1], [b0, b1], [c0, c1], value=0.5)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_fuse_concat(self):
        def fn(x1, x2, x3, w1, w2, w3):
            x = torch.stack([x1, x2, x3])
            w = torch.stack([w1, w2, w3])

            y = torch.bmm(x, w)

            return y

        x1 = torch.randn(5, 4).cuda()
        x2 = x1 + 1
        x3 = x1 + 2
        w1 = torch.randn(4, 3).cuda()
        w2 = w1 + 1
        w3 = w1 + 2

        args = (x1, x2, x3, w1, w2, w3)

        self.check_model_cuda(fn, args)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    def test_zero_elems(self):
        def fn(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(0, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(0, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @bin_ops
    def test_2d_blocking(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 40, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(40, 10, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @bin_ops
    def test_2d_blocking_partitioning(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(40, 30, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(30, 40, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    @bin_ops
    def test_2d_blocking_partitioning_elems(self, op):
        """2D blocking should be grouped by the number of y elems"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 20, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(20, 10, device="cuda:0").t(),
                torch.rand(20, 30, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
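
    # Same input shapes as test_2d_blocking_partitioning_elems above, but with
    # combo_kernel_allow_mixed_sizes raised to 2: tensors with different y-dim
    # sizes may now share one combo kernel, so only a single kernel is expected
    # instead of two.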
    @requires_cuda
    @bin_ops
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_blocking_partitioning_mixed_sizes(self, op):
        """2D blocking with mixed sizes should be grouped together"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 20, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(20, 10, device="cuda:0").t(),
                torch.rand(20, 30, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing(self, op):
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing_mut_before(self, op):
        def fn(a0, a1, b0, b1):
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing_mut_after(self, op):
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_multi_device(self):
        def test_foreach_add(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        inps = [
            torch.ones(10, 10, device="cuda"),
            torch.ones(20, 20, device="cpu"),
            torch.zeros(10, 10, device="cuda"),
            torch.zeros(20, 20, device="cpu"),
        ]

        out_eager = test_foreach_add(*inps)
        out_compiled = torch.compile(test_foreach_add)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda
    def test_aliasing(self):
        def test_foreach_add(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add_([a0, a1, a2], [b0, b1, b2])

        input = torch.ones(10, 10, device="cuda")
        input2 = torch.ones(10, 10, device="cuda")
        inps = [
            input,
            input.view(10, 10),
            input.view(10, 10),
            input2,
            input2.view(10, 10),
            input2.view(10, 10),
        ]

        out_eager = test_foreach_add(*inps)
        out_compiled = torch.compile(test_foreach_add)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

    @requires_cuda
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1)
    def test_2d_block_no_mixed_sizes_no_mask(self):
        """2D blocking with no mixed sizes should use a constant mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 2048, device="cuda:0"),
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 1024, device="cuda:0").t(),
                torch.rand(2048, 2048, device="cuda:0").t(),
                torch.rand(2048, 1024, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
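
    # Counterpart to the test above: with mixed sizes allowed (level 2), the
    # differently shaped sub-problems presumably land in one combo kernel,
    # which then needs a real bounds mask rather than a constant one.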
    @requires_cuda
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_block_mixed_sizes_with_mask(self):
        """2D blocking with mixed sizes should use a mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 2048, device="cuda:0"),
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 1024, device="cuda:0").t(),
                torch.rand(2048, 2048, device="cuda:0").t(),
                torch.rand(2048, 1024, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")