# Owner(s): ["module: inductor"]
import math
import os
import sys

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

import contextlib
import unittest

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
from torch._inductor import config
from torch._inductor.scheduler import Scheduler


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "benchmark_kernel": True,
                    "benchmark_fusion": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()


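# Shared test bodies. BenchmarkFusionTestTemplate is intentionally not a TestCase;
# copy_tests() at the bottom of this file copies its test methods onto the
# device-specific classes, which supply `common` and `device`.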
class BenchmarkFusionTestTemplate:
    def test_softmax(self):
        def f(x):
            return torch.nn.functional.softmax(x, dim=-1)

        self.common(f, (torch.rand(2, 8192),))

    @slowTest
    def test_resnet18(self):
        import torchvision

        model = torchvision.models.resnet18()
        model.eval()
        batch_size = 16
        inputs = (torch.randn((batch_size, 3, 224, 224)),)
        self.common(model, inputs, atol=1e-2, rtol=1e-2)

    def test_register_spills(self):
        """
        The test can potentially trigger register spills
        """
        old_benchmark_fn = Scheduler.benchmark_fused_nodes

        def new_benchmark_fn(scheduler, nodes):
            """
            Override Scheduler.benchmark_fused_nodes to return a latency of 1.0
            whenever there are no register spills. Without this, we may not be
            able to exercise the code path that handles register spilling,
            because the related fusion may already have been skipped for having
            a longer latency before registers start to spill.
            """
            ms, path = old_benchmark_fn(scheduler, nodes)
            if not math.isinf(ms):
                ms = 1.0
            return ms, path

        # Disable dynamic_scale_rblock to make it easier to trigger register
        # spilling.
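        # (dynamic_scale_rblock normally lets the Triton heuristics shrink the
        # reduction block size when high register pressure is detected, which
        # would make spills harder to reproduce here.)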
        with unittest.mock.patch.object(
            Scheduler, "benchmark_fused_nodes", new_benchmark_fn
        ), config.patch("dynamic_scale_rblock", False):
            S = 512

            def f(*inputs):
                inputs = list(inputs)
                outputs = []
                out = torch.zeros(S, device=self.device)
                for x in inputs:
                    x = x * 2
                    x = x + 1
                    x = x.sum(dim=-1)
                    outputs.append(x)
                    out = out + x
                return outputs, out

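            # NINP controls how many reduction inputs (and thus fusion
            # candidates) are created, e.g.:
            #   NINP=50 python test/inductor/test_benchmark_fusion.py -k test_register_spills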
            N = int(os.environ.get("NINP", "30"))
            inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)]
            opt_f = torch.compile(f)
            opt_f(*inputs)

    def test_foreach_kernel(self):
        """
        For now, benchmark fusion should skip benchmarking kernels that involve
        a foreach kernel. Without the skipping logic, `codegen_node_schedule`
        may fail.
        """
        a = torch.randn(1024, 256, device=self.device)
        b = torch.randn(1024, 512, device=self.device)

        def f(a, b):
            a, b = torch._foreach_abs([a, b])
            return a + 1, b + 2

        self.common(f, (a, b))

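    # The patch below limits GEMM max-autotune choices to Triton template
    # kernels (rather than ATEN/cuBLAS), so the linear layer is emitted as
    # Triton code whose fusion decisions show up in the generated output.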
    @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
    def test_avoid_register_spilling(self):
        if self.device != "cuda":
            raise unittest.SkipTest("CUDA only")

        from torch.nn.functional import gelu

        def foo(m, inp):
            curr = m(inp)
            tmps = []
            for _ in range(4):
                curr = gelu(curr)
                for t in tmps:
                    curr = curr + t
                tmps.append(curr)

            return curr

        m = torch.nn.Linear(2048, 2048, bias=True).half().cuda()
        inp = torch.rand([2048, 2048]).half().cuda()

        with torch.no_grad():
            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)

            _, out_code = run_and_get_code(foo_c, m, inp)

            # Occasionally CI generates this as a single kernel; just skip the
            # check in that case.
            if out_code[0].count("def triton_") != 2:
                return

            # should be multiple triton invocations
            FileCheck().check("async_compile.wait").check_count(
                ".run", 2, exactly=True
            ).run(out_code[0])

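        # Recompile with benchmark fusion and epilogue fusion disabled so the
        # generated code of both variants can be compared below.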
        with config.patch(
            {"benchmark_fusion": False, "epilogue_fusion": False}
        ), torch.no_grad():
            torch._dynamo.reset()

            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)

            _, out_code2 = run_and_get_code(foo_c, m, inp)

        for c in out_code[0], out_code2[0]:
            FileCheck().check("async_compile.wait").check("DeviceGuard").check_count(
                "empty_strided_cuda", 2, exactly=True
            ).check("return").run(c)

    def test_tiled_kernel_fusion(self):
        def f(x):
            y = realize(x + x.t())
            return y + 1

        x = torch.randn(1024, 1024, device=self.device)
        self.common(f, (x,))


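# CUDA variants are skipped under ASAN. BenchmarkMultiTemplateFusionCudaTest
# additionally enables benchmark_epilogue_fusion so that epilogue-fusion
# decisions for autotuned templates are benchmarked as well.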
if HAS_CUDA and not TEST_WITH_ASAN:

    class BenchmarkFusionCudaTest(TestCase):
        common = check_model_cuda
        device = "cuda"

    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda")

    class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls._stack = contextlib.ExitStack()
            cls._stack.enter_context(
                config.patch(
                    {
                        "benchmark_kernel": True,
                        "benchmark_fusion": True,
                        "benchmark_epilogue_fusion": True,
                    }
                )
            )

        @classmethod
        def tearDownClass(cls):
            cls._stack.close()
            super().tearDownClass()

        def setUp(self):
            super().setUp()
            if not is_big_gpu(0):
                return self.skipTest("Need a big GPU to run max_autotune=True")

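        # Compiles `foo` twice, once with benchmark_epilogue_fusion enabled (the
        # class default) and once with it disabled, and returns both generated
        # code listings so the callers can compare them.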
        def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):
            def foo(m, inp):
                a = m(inp)
                if activation:
                    return torch.nn.functional.relu(a)
                return a

            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
            first_dim = first_dim if first_dim is not None else size

            m = torch.nn.Linear(size, size, bias=True).half().cuda()
            inp = torch.rand([first_dim, size]).half().cuda()

            with torch.no_grad():
                res, code = run_and_get_code(foo_c, m, inp)

            torch._dynamo.reset()
            with unittest.mock.patch.object(
                torch._inductor.config, "benchmark_epilogue_fusion", False
            ):
                foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
                with torch.no_grad():
                    res2, code2 = run_and_get_code(foo_c, m, inp)

            self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
            return code, code2

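        # With only the TRITON GEMM backend, both variants are expected to emit
        # a single fused template kernel (triton_tem_fused_relu_0) and a single
        # output allocation.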
        @fresh_inductor_cache()
        @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
        def test_equivalent_template_code(self):
            code, code2 = self._equivalent_output_code_impl(256)
            for out_code in [code, code2]:
                FileCheck().check("def call").check_count(
                    "empty_strided_cuda", 1, exactly=True
                ).check("triton_tem_fused_relu_0.run").check_count(
                    "del", 3, exactly=True
                ).check(
                    "return"
                ).run(
                    out_code[0]
                )

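        # With only the ATEN GEMM backend, the matmul stays an extern kernel
        # call; the allocation / del / return pattern should still match across
        # both configurations.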
        @fresh_inductor_cache()
        @torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
        def test_equivalent_extern_code(self):
            torch._dynamo.reset()

            code, code2 = self._equivalent_output_code_impl(512, 1, False)

            for out_code in [code, code2]:
                FileCheck().check("def call").check_count(
                    "empty_strided_cuda", 1, exactly=True
                ).check("extern_kernels.").check_count("del", 3, exactly=True).check(
                    "return"
                ).run(
                    out_code[0]
                )

        def test_changed_layout(self):
            # cat/addmm planning will change the layout - make sure the change is propagated
            def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
                return torch.cat(
                    [
                        torch.addmm(a, b, c),
                        torch.addmm(b, c, a),
                    ],
                    1,
                )

            args = [
                torch.randn(4, 4, device="cuda"),
                torch.randn(4, 4, device="cuda"),
                torch.randn(4, 4, device="cuda"),
            ]

            expected = fn(*args)
            actual = torch.compile(fn, mode="max-autotune")(*args)
            self.assertEqual(expected, actual)

            torch._dynamo.reset()


if HAS_CPU and not torch.backends.mps.is_available():

    class BenchmarkFusionCpuTest(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu")
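# Running this file directly invokes run_tests() below, which executes the
# test classes instantiated above when a CPU or CUDA backend is available.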

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests()