xref: /aosp_15_r20/external/pytorch/test/inductor/test_foreach.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2
3import sys
4import unittest
5
6import torch
7import torch._inductor
8from torch._inductor.test_case import TestCase
9from torch.testing._internal.common_utils import (
10    instantiate_parametrized_tests,
11    IS_FBCODE,
12    parametrize,
13)
14from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
15from torch.testing._internal.triton_utils import requires_cuda
16
17
# Short alias for the ATen operator namespace.
aten = torch.ops.aten

# Pull in the shared model-checking helpers. The relative import is tried
# first; if that raises ImportError the module is imported by its plain
# top-level name instead.
try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
    # Report why the helpers are unavailable before deciding how to bail out.
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ != "__main__":
        # Imported by something else: let the importer see the original error.
        raise
    # Run directly as a script: exit with status 0 rather than a failure.
    sys.exit(0)
30
# In-place binary foreach ops.
inplace_bin_ops_under_test = [
    torch._foreach_add_,
    torch._foreach_mul_,
    torch._foreach_sub_,
    torch._foreach_div_,
]

# Out-of-place binary foreach ops. The order matters: the scalar and
# scalar-tensor parametrizations below are taken as prefixes of this list.
bin_ops_under_test = [
    torch._foreach_add,
    torch._foreach_mul,
    torch._foreach_sub,
    torch._foreach_div,
    torch._foreach_maximum,
    torch._foreach_minimum,
    torch._foreach_clamp_max,
    torch._foreach_clamp_min,
    aten._foreach_copy,
]

# Unary foreach ops.
un_ops_under_test = [
    torch._foreach_reciprocal,
    torch._foreach_neg,
    torch._foreach_sign,
    torch._foreach_abs,
    torch._foreach_sqrt,
]

# Three-list foreach ops, exercised by the decomposition test.
compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]


def _parametrize_ops(ops):
    """Return a ``parametrize`` decorator over ``ops`` for the ``op`` argument.

    Each generated sub-test is named after the op's ``__name__``.
    """
    return parametrize("op", ops, name_fn=lambda f: f.__name__)


all_ops = _parametrize_ops(bin_ops_under_test + un_ops_under_test)
bin_ops = _parametrize_ops(bin_ops_under_test)
inplace_bin_ops = _parametrize_ops(inplace_bin_ops_under_test)
# First four binary ops: add, mul, sub, div.
scalar_bin_ops = _parametrize_ops(bin_ops_under_test[:4])
# First two binary ops: add, mul.
scalar_tensor_bin_ops = _parametrize_ops(bin_ops_under_test[:2])
decomp_ops = _parametrize_ops(compose_ops)
70
71
def gen_args(op):
    """Return random ``cuda:0`` input tensors suited to ``op``.

    Unary ops (members of ``un_ops_under_test``) get two tensors of shapes
    (10, 10) and (20, 20). Every other op gets four tensors: that pair of
    shapes repeated twice, in the same order.
    """
    shapes = [(10, 10), (20, 20)]
    if op not in un_ops_under_test:
        # Non-unary ops take a second list with matching shapes.
        shapes = shapes * 2
    return tuple(torch.rand(*shape, device="cuda:0") for shape in shapes)
85
86
@instantiate_parametrized_tests
class ForeachTests(TestCase):
    """Inductor tests for ``torch._foreach_*`` ops.

    Most tests run a small function built around a foreach op through the
    shared ``check_model`` helpers (or compile it directly and compare with
    eager) and then assert how many kernels Inductor reports in
    ``torch._inductor.metrics.generated_kernel_count``.
    """

    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    # Not read anywhere in this file; presumably consumed by the shared
    # check_model helpers -- TODO confirm.
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        # Zero the Inductor metrics so the kernel-count assertions only see
        # kernels generated by the test that is about to run.
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        # Leave clean metrics behind for whatever runs next.
        torch._inductor.metrics.reset()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cuda_rand(*shapes):
        """Return a tuple with one ``torch.rand`` tensor on ``cuda:0`` per shape.

        Tensors are created left to right, in the order the shapes are given.
        """
        return tuple(torch.rand(*shape, device="cuda:0") for shape in shapes)

    @staticmethod
    def _cuda_rand_t(*shapes):
        """Like ``_cuda_rand``, but every tensor is returned as its ``.t()`` view."""
        return tuple(torch.rand(*shape, device="cuda:0").t() for shape in shapes)

    def _assert_kernel_count(self, expected):
        """Assert that Inductor has generated exactly ``expected`` kernels."""
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)

    def _assert_compiled_matches_eager(self, fn, inputs):
        """Run ``fn`` via ``torch._dynamo.optimize()`` and eagerly; assert equal results.

        The compiled call happens first, then the eager call, both on the
        same ``inputs``.
        """
        fn_opt = torch._dynamo.optimize()(fn)
        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)

    def _test_single_list(self, op):
        """Check ``op`` on tensor lists: one list for unary ops, two otherwise."""
        if op in un_ops_under_test:

            def fn(a0, a1):
                return op([a0, a1])

        else:

            def fn(a0, a1, b0, b1):
                return op([a0, a1], [b0, b1])

        self.check_model_cuda(fn, gen_args(op))

    def _test_single_scalar(self, op):
        """Check ``op`` applied to a tensor list and the Python scalar 3.3."""

        def fn(a0, a1):
            return op([a0, a1], 3.3)

        self.check_model_cuda(fn, self._cuda_rand((10, 10), (20, 20)))

    def _test_single_scalar_tensor(self, op):
        """Check ``op`` applied to a tensor list and a 0-dim CUDA tensor (3.3)."""

        def fn(a0, a1):
            return op([a0, a1], torch.tensor(3.3, device="cuda:0"))

        self.check_model_cuda(fn, self._cuda_rand((10, 10), (20, 20)))

    # ------------------------------------------------------------------
    # Single foreach op
    # ------------------------------------------------------------------

    # called in test_cuda_cpp_wrapper.py
    @requires_cuda
    def test_foreach_cpp_wrapper_cuda(self):
        self._test_single_list(op=torch._foreach_add)

    @requires_cuda
    @all_ops
    def test_single_list(self, op):
        self._test_single_list(op)
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_single_scalar(self, op):
        self._test_single_scalar(op)
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_tensor_bin_ops
    def test_single_scalar_tensor(self, op):
        self._test_single_scalar_tensor(op)
        self._assert_kernel_count(1)

    # ------------------------------------------------------------------
    # Foreach op feeding another foreach op
    # ------------------------------------------------------------------

    @requires_cuda
    @all_ops
    def test_scheduler_fusion_list(self, op):
        # The result of `op` is consumed by a second foreach op; both are
        # expected to end up in a single generated kernel.
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch._foreach_sqrt(c)

        else:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(fn, gen_args(op))
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_scheduler_fusion_scalar(self, op):
        # Scalar variant of the test above.
        def fn(a0, a1):
            c = op([a0, a1], 3.4)
            return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(fn, self._cuda_rand((10, 10), (20, 20)))
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_broadcasting(self, op):
        # First pair is (10, 1) against (1, 10), so the op has to broadcast.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand((10, 1), (20, 20), (1, 10), (20, 20))
        self._assert_compiled_matches_eager(fn, inputs)
        self._assert_kernel_count(1)

    @requires_cuda
    @all_ops
    def test_singleton_lists(self, op):
        # Lists of length one still have to produce exactly one kernel.
        if op in un_ops_under_test:

            def fn(a0):
                return op([a0])

            args = self._cuda_rand((10, 10))
        else:

            def fn(a0, b0):
                return op([a0], [b0])

            args = self._cuda_rand((10, 10), (10, 10))

        self.check_model_cuda(fn, args)
        self._assert_kernel_count(1)

    @requires_cuda
    @bin_ops
    def test_type_promotion(self, op):
        # Second pair mixes int32 with int64; values are drawn up to each
        # dtype's maximum.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        specs = (
            (torch.int32, (10, 10)),
            (torch.int32, (20, 20)),
            (torch.int32, (10, 10)),
            (torch.int64, (20, 20)),
        )
        inputs = tuple(
            torch.randint(
                torch.iinfo(dtype).max, shape, device="cuda:0", dtype=dtype
            )
            for dtype, shape in specs
        )

        self._assert_compiled_matches_eager(fn, inputs)
        self._assert_kernel_count(1)

    # ------------------------------------------------------------------
    # Kernel splitting on argument count
    # ------------------------------------------------------------------

    @requires_cuda
    @scalar_bin_ops
    def test_kernel_split_arg_limit_list(self, op):
        # NB: foreach_copy won't pass this test because it will dce one set of buffers

        def fn(a, b):
            return op(a, b)

        max_args = 370
        # One list element more than max_args // 3 (presumably three kernel
        # arguments per element: two inputs and one output -- TODO confirm),
        # so two kernels are expected.
        max_list_len = (max_args // 3) + 1
        inputs = (
            list(self._cuda_rand(*[(10, 10)] * max_list_len)),
            list(self._cuda_rand(*[(10, 10)] * max_list_len)),
        )

        self._assert_compiled_matches_eager(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @scalar_bin_ops
    @unittest.skip(
        "Triton recursion depth exceeded: https://github.com/openai/triton/issues/1763"
    )
    def test_kernel_split_arg_limit_scalar(self, op):
        def fn(a):
            return op(a, 3.3)

        max_args = 370
        # Scalar variant: the list length is derived from max_args // 2.
        max_list_len = (max_args // 2) + 1
        inputs = (list(self._cuda_rand(*[(10, 10)] * max_list_len)),)

        self._assert_compiled_matches_eager(fn, inputs)
        self._assert_kernel_count(2)

    # ------------------------------------------------------------------
    # Interaction with duplicate buffers and non-foreach nodes
    # ------------------------------------------------------------------

    @requires_cuda
    @bin_ops
    def test_fusion_duplicate_buffer_list(self, op):
        # The second op reads c[0] twice, i.e. the same buffer appears in
        # two list positions.
        def fn(a0, a1, b0, b1):
            c = op([a0, a1], [b0, b1])
            return op([a0, b0], [c[0], c[0]])

        self.check_model_cuda(
            fn,
            self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20)),
            reference_in_float=False,
            check_lowp=False,
        )
        self._assert_kernel_count(1)

    @requires_cuda
    @all_ops
    def test_non_foreach_consumer_list(self, op):
        # A plain torch.mul consumes one element of the foreach result.
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch.mul(c[0], a0)

        else:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return torch.mul(c[0], a0)

        self.check_model_cuda(fn, gen_args(op))
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_consumer_scalar(self, op):
        def fn(a0, a1):
            c = op([a0, a1], 4.7)
            return torch.mul(c[0], a0)

        self.check_model_cuda(fn, self._cuda_rand((10, 10), (20, 20)))
        self._assert_kernel_count(1)

    @requires_cuda
    @all_ops
    def test_non_foreach_producer_list(self, op):
        # Plain torch.add nodes produce the tensors the foreach op reads.
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.add(a1, a1)
                return op([c0, c1])

        else:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                return op([a0, a1], [c0, c1])

        self.check_model_cuda(
            fn, gen_args(op), reference_in_float=False, check_lowp=False
        )
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.mul(a0, b0)
            c1 = torch.mul(a1, b1)
            return op([c0, c1], 5.6)

        self.check_model_cuda(
            fn, self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))
        )
        self._assert_kernel_count(1)

    @requires_cuda
    @all_ops
    def test_non_foreach_consumer_producer_list(self, op):
        # Pointwise producers -> foreach op -> pointwise consumers.
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.mul(a1, a1)
                d = op([c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        else:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                d = op([a0, a1], [c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        self.check_model_cuda(
            fn,
            gen_args(op),
            reference_in_float=False,
            check_lowp=False,
        )
        self._assert_kernel_count(1)

    @requires_cuda
    @scalar_bin_ops
    def test_non_foreach_consumer_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            d = op([c0, c1], 5.8)
            e0 = torch.mul(d[0], a0)
            e1 = torch.mul(d[1], a1)
            return [e0, e1]

        self.check_model_cuda(
            fn,
            self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20)),
            reference_in_float=False,
            check_lowp=False,
        )
        self._assert_kernel_count(1)

    # ------------------------------------------------------------------
    # Dynamic shapes
    # ------------------------------------------------------------------

    @requires_cuda
    @bin_ops
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_dynamic_shapes_fallback(self, op):
        # With static-by-default turned off, two kernels are expected here
        # rather than one.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    def test_enable_dynamic_shapes_python_wrapper(self, op=torch._foreach_add):
        # Same setup as the fallback test, but with
        # combo_kernel_foreach_dynamic_shapes enabled a single kernel is expected.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(1)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_enable_dynamic_shapes_cpp_wrapper_cuda(self, op=torch._foreach_add):
        # cpp_wrapper variant; only correctness is checked, not the kernel count.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs)

    # ------------------------------------------------------------------
    # CPU, decompositions, misc
    # ------------------------------------------------------------------

    @unittest.skipIf(IS_FBCODE, "cpp compile not supported in fbcode")
    @bin_ops
    def test_cpu_cpp_fallback(self, op):
        # CPU inputs: two kernels are expected for the two list elements.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        shapes = ((10, 10), (20, 20), (10, 10), (20, 20))
        inputs = tuple(torch.rand(*shape, device="cpu") for shape in shapes)

        self.check_model_cpu(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @decomp_ops
    def test_decomp(self, op):
        # Three-list ops (addcdiv / addcmul) with value=0.5.
        def fn(a0, a1, b0, b1, c0, c1):
            return op([a0, a1], [b0, b1], [c0, c1], value=0.5)

        self.check_model_cuda(
            fn,
            self._cuda_rand(
                (10, 10), (20, 20), (10, 10), (20, 20), (10, 10), (20, 20)
            ),
        )
        self._assert_kernel_count(1)

    @requires_cuda
    def test_fuse_concat(self):
        # Two torch.stack calls feeding a bmm; two generated kernels expected.
        def fn(x1, x2, x3, w1, w2, w3):
            x = torch.stack([x1, x2, x3])
            w = torch.stack([w1, w2, w3])
            return torch.bmm(x, w)

        x_base = torch.randn(5, 4).cuda()
        xs = (x_base, x_base + 1, x_base + 2)
        w_base = torch.randn(4, 3).cuda()
        ws = (w_base, w_base + 1, w_base + 2)

        self.check_model_cuda(fn, xs + ws)
        self._assert_kernel_count(2)

    @requires_cuda
    def test_zero_elems(self):
        # First list element of each operand has zero elements.
        def fn(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn, self._cuda_rand((0,), (10, 10), (0,), (10, 10))
        )
        self._assert_kernel_count(1)

    # ------------------------------------------------------------------
    # 2D blocking (second operand given as transposed views)
    # ------------------------------------------------------------------

    @requires_cuda
    @bin_ops
    def test_2d_blocking(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand((10, 40), (10, 30)) + self._cuda_rand_t(
            (40, 10), (30, 10)
        )

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(1)

    @requires_cuda
    @bin_ops
    def test_2d_blocking_partitioning(self, op):
        # Only the last input is a transposed view; two kernels expected.
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = self._cuda_rand(
            (30, 20), (40, 30), (30, 20)
        ) + self._cuda_rand_t((30, 40))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @bin_ops
    def test_2d_blocking_partitioning_elems(self, op):
        """2D blocking should be grouped by number of yelems"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        inputs = self._cuda_rand(
            (10, 20), (30, 20), (10, 30)
        ) + self._cuda_rand_t((20, 10), (20, 30), (30, 10))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @bin_ops
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_blocking_partitioning_mixed_sizes(self, op):
        """2D blocking with mixed sizes should group together"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        inputs = self._cuda_rand(
            (10, 20), (30, 20), (10, 30)
        ) + self._cuda_rand_t((20, 10), (20, 30), (30, 10))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(1)

    # ------------------------------------------------------------------
    # In-place ops
    # ------------------------------------------------------------------

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing(self, op):
        # In-place foreach op; the mutated inputs are returned.
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs, check_lowp=False)
        self._assert_kernel_count(1)

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing_mut_before(self, op):
        # a0 is additionally mutated *before* the in-place foreach op.
        def fn(a0, a1, b0, b1):
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs, check_lowp=False)
        self._assert_kernel_count(1)

    @requires_cuda
    @inplace_bin_ops
    def test_reinplacing_mut_after(self, op):
        # a0 is additionally mutated *after* the in-place foreach op.
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            return [a0, a1]

        inputs = self._cuda_rand((10, 10), (20, 20), (10, 10), (20, 20))

        self.check_model_cuda(fn, inputs, check_lowp=False)
        self._assert_kernel_count(1)

    @requires_cuda
    def test_multi_device(self):
        # List elements live on different devices (cuda and cpu); two
        # kernels are expected.
        def fn(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        inps = [
            torch.ones(10, 10, device="cuda"),
            torch.ones(20, 20, device="cpu"),
            torch.zeros(10, 10, device="cuda"),
            torch.zeros(20, 20, device="cpu"),
        ]

        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self._assert_kernel_count(2)

    @requires_cuda
    def test_aliasing(self):
        # Every list element is the same storage seen through a view, and
        # the op is in-place, so the eager call mutates the inputs before the
        # compiled call runs.
        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add_([a0, a1, a2], [b0, b1, b2])

        base_a = torch.ones(10, 10, device="cuda")
        base_b = torch.ones(10, 10, device="cuda")
        inps = [
            base_a,
            base_a.view(10, 10),
            base_a.view(10, 10),
            base_b,
            base_b.view(10, 10),
            base_b.view(10, 10),
        ]

        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self._assert_kernel_count(4)

    # ------------------------------------------------------------------
    # 2D blocking on large inputs
    # ------------------------------------------------------------------

    @requires_cuda
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1)
    def test_2d_block_no_mixed_sizes_no_mask(self):
        """2D blocking with no mixed sizes constant mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        inputs = self._cuda_rand(
            (1024, 2048), (2048, 2048), (1024, 2048)
        ) + self._cuda_rand_t((2048, 1024), (2048, 2048), (2048, 1024))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(2)

    @requires_cuda
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_block_mixed_sizes_with_mask(self):
        """2D blocking with mixed sizes should have mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        inputs = self._cuda_rand(
            (1024, 2048), (2048, 2048), (1024, 2048)
        ) + self._cuda_rand_t((2048, 1024), (2048, 2048), (2048, 1024))

        self.check_model_cuda(fn, inputs)
        self._assert_kernel_count(1)
824
825
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only hand over to the test runner when at least one backend is usable;
    # otherwise the script ends without running anything.
    backend_available = HAS_CPU or HAS_CUDA
    if backend_available:
        run_tests(needs="filelock")
831