# Owner(s): ["module: inductor"]
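# Tests for the Composable Kernel (CK) GEMM backend of Inductor's max-autotune
# path on ROCm: plain matmul, dynamic shapes, preselected instances,
# non-contiguous inputs, and addmm.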
import logging
import os
import unittest

import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

log = logging.getLogger(__name__)


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


@instantiate_parametrized_tests
class TestCKBackend(TestCase):
    def setUp(self):
        # The new inductor cache refresh mechanism
        # introduced with https://github.com/pytorch/pytorch/pull/122661
        # interacts badly with persistent subprocesses during
        # autotuning. So we need to disable automatic cache refresh
        # before calling setUp() on the parent class.
        old_disable_fresh_cache_envvar = os.environ.get(
            "INDUCTOR_TEST_DISABLE_FRESH_CACHE", ""
        )

        torch.random.manual_seed(1234)
        try:
            import ck4inductor

            self.ck_dir = os.path.dirname(ck4inductor.__file__)
            os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir
        except ImportError as e:
            raise unittest.SkipTest("Composable Kernel library not installed") from e

        try:
            os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
            super().setUp()
        finally:
            os.environ[
                "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
            ] = old_disable_fresh_cache_envvar

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CK path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    @parametrize("autotune_in_subproc", (True, False))
    def test_max_autotune_precompile_matmul(
        self, max_autotune_gemm_backends, autotune_in_subproc
    ):
        """
        Make sure autotuning mm doesn't crash.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        tensor_options = {"device": "cuda", "dtype": torch.bfloat16}

        a = torch.randn(2240, 256, **tensor_options)
        b = torch.randn(256, 2048, **tensor_options)

        assert "rocm" in dir(config)

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": autotune_in_subproc,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.n_max_profiling_configs": 2,
                "rocm.ck_dir": self.ck_dir,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CK path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK",))
    @parametrize("autotune_in_subproc", (True,))
    def test_max_autotune_precompile_matmul_dynamic(
        self, max_autotune_gemm_backends, autotune_in_subproc
    ):
        """
        Test matmul with dynamic shapes
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        tensor_options = {"device": "cuda", "dtype": torch.bfloat16}

        a = torch.randn(2240, 256, **tensor_options)
        b = torch.randn(256, 2048, **tensor_options)

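        # Mark the row dimension of `a` as dynamic so the compiled graph below
        # can be reused for the differently sized `a1`.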
        torch._dynamo.mark_dynamic(a, 0)

        assert "rocm" in dir(config)

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": autotune_in_subproc,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.n_max_profiling_configs": 2,
                "rocm.ck_dir": self.ck_dir,
            }
        ):

            @torch.compile(dynamic=True)
            def compiled_mm(a, b):
                return a @ b

            Y_compiled = compiled_mm(a, b)
            Y = a @ b
            torch.testing.assert_close(Y_compiled, Y)

            a1 = torch.randn(1024, 256, **tensor_options)
            Y1_compiled = compiled_mm(a1, b)
            Y1 = a1 @ b
            torch.testing.assert_close(Y1_compiled, Y1)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CK path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    def test_max_autotune_precompile_preselected(self, max_autotune_gemm_backends):
        """
        End-to-end test for picking preselected CK instances.
153        """
154
155        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
156
157        def mm(a, b):
158            return a @ b
159
160        tensor_options = {"device": "cuda", "dtype": torch.float16}
161
162        a = torch.randn(2240, 256, **tensor_options)
163        b = torch.randn(2048, 256, **tensor_options).transpose(0, 1)
164
165        assert "rocm" in dir(config)
166
167        with config.patch(
168            {
169                "max_autotune": True,
170                "autotune_in_subproc": True,
171                "max_autotune_gemm_backends": max_autotune_gemm_backends,
172                "compile_threads": 12,
173                "rocm.ck_dir": self.ck_dir,
174                "rocm.use_preselected_instances": True,
175            }
176        ):
177            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
178            Y = mm(a, b)
179            torch.testing.assert_close(Y_compiled, Y)
180
181    @unittest.skipIf(not torch.version.hip, "ROCM only")
182    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CK path setup")
183    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
184    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
185    def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends):
186        """
        Make sure the CK template works with non-contiguous inputs.
188        """
189
190        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
191
192        tensor_options = {"device": "cuda", "dtype": torch.float16}
193
194        a = torch.empty_strided((50257, 32768), (1, 50304), **tensor_options)
195        b = torch.empty_strided((32768, 768), (768, 1), **tensor_options)
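        # `a` is laid out column-major with padded strides (1, 50304), hence
        # non-contiguous; `b` is a regular contiguous row-major matrix.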

        assert "rocm" in dir(config)

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.ck_dir": self.ck_dir,
                "rocm.n_max_profiling_configs": 2,
            }
        ):

            @torch.compile(dynamic=False)
            def mm(a, b):
                return a @ b

            Y_compiled = mm(a, b)
            Y_eager = a @ b
            torch.testing.assert_close(Y_compiled, Y_eager)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CK path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    @parametrize("x_shape", ([4096, 2048], [2048], [4096, 1]))
    def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        m, k, n = 4096, 224, 2048
        alpha, beta = 1.0, 1.0

        tensor_options = {"device": "cuda", "dtype": torch.float16}
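        # The parametrized x_shape covers the bias broadcasting cases of addmm:
        # a full (m, n) matrix, a length-n row vector, and an (m, 1) column vector.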
        x = torch.ones(x_shape, **tensor_options)
        a = torch.randn(m, k, **tensor_options)
        b = torch.randn(k, n, **tensor_options)

        assert "rocm" in dir(config)

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.ck_dir": self.ck_dir,
                "rocm.n_max_profiling_configs": 2,
            }
        ):

            @torch.compile(dynamic=False)
            def addmm(x, a, b, alpha, beta):
                return torch.addmm(x, a, b, alpha=alpha, beta=beta)

            Y_compiled = addmm(x, a, b, alpha, beta)
            Y_eager = torch.addmm(x, a, b, alpha=alpha, beta=beta)

            torch.testing.assert_close(Y_compiled, Y_eager)


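# A sketch of how to run this file locally (assumes a ROCm build of PyTorch and
# the ck4inductor package, which setUp() uses to locate the CK sources):
#   python test/inductor/test_ck_backend.py
#   pytest test/inductor/test_ck_backend.py -k matmul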
if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Only run when both a CPU backend and a sufficiently big GPU are available (as in CI).
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()