# Owner(s): ["module: inductor"]
import sys
import unittest
from typing import NamedTuple

import torch
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import (
    get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CUDA


try:
    try:
        from . import (
            test_combo_kernels,
            test_foreach,
            test_pattern_matcher,
            test_select_algorithm,
            test_torchinductor,
            test_torchinductor_dynamic_shapes,
        )
    except ImportError:
        import test_combo_kernels

        import test_foreach
        import test_pattern_matcher
        import test_select_algorithm
        import test_torchinductor
        import test_torchinductor_dynamic_shapes
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


_desired_test_bases = get_desired_device_type_test_bases()
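# Run these CUDA wrapper tests only when CUDA is available, a CUDA device test
# base was requested, and we are not running under ASAN.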
RUN_CUDA = (
    HAS_CUDA
    and any(getattr(x, "device_type", "") == "cuda" for x in _desired_test_bases)
    and not TEST_WITH_ASAN
)


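# Empty template class: make_test_case attaches wrapped tests to it, and
# copy_tests later copies them onto the concrete test classes below.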
class CudaWrapperTemplate:
    pass


class TestCudaWrapper(InductorTestCase):
    device = "cuda"


class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
    device = "cuda"


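# Known failures/skips for the generated wrapper tests, keyed by template test name.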
test_failures_cuda_wrapper = {
    "test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
        ("cuda_wrapper",), is_skip=True
    ),
}


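# In ABI-compatible mode, tests listed below are marked as expected failures or
# skips for both the static- and dynamic-shapes variants (both lists are
# currently empty).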
if config.abi_compatible:
    xfail_list = []
    for test_name in xfail_list:
        test_failures_cuda_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cuda_wrapper",), is_skip=False
        )
        test_failures_cuda_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=False)
    skip_list = []
    for test_name in skip_list:
        test_failures_cuda_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cuda_wrapper",), is_skip=True
        )
        test_failures_cuda_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)


def make_test_case(
    name,
    device,
    tests,
    condition=True,
    slow=False,
    func_inputs=None,
    code_string_count=None,
):
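    """
    Wrap an existing Inductor test so it reruns with cpp_wrapper enabled.

    Looks up ``{name}_{device}`` (or just ``name`` when no device is given) on
    ``tests``, reruns it while capturing the generated wrapper code, and checks
    that the code came from CppWrapperCodeCache, plus any substring counts given
    in ``code_string_count``. The wrapped test is attached to CudaWrapperTemplate
    when ``condition`` is true.
    """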
    test_name = f"{name}_{device}" if device else name
    if code_string_count is None:
        code_string_count = {}

    func = getattr(tests, test_name)
    assert callable(func), "not a callable"
    func = slowTest(func) if slow else func

    @config.patch(cpp_wrapper=True)
    def fn(self):
        tests.setUpClass()
        tests.setUp()
        try:
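            # Force the Dense dispatch key to be included while the test runs;
            # _PreserveDispatchKeyGuard restores the previous TLS state on exit.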
            with torch._C._PreserveDispatchKeyGuard():
                torch._C._dispatch_tls_set_dispatch_key_included(
                    torch._C.DispatchKey.Dense, True
                )

                _, code = test_torchinductor.run_and_get_cpp_code(
                    func, *func_inputs if func_inputs else []
                )
                self.assertEqual("CppWrapperCodeCache" in code, True)
                self.assertTrue(
                    all(
                        code.count(string) == code_string_count[string]
                        for string in code_string_count
                    )
                )
        finally:
            tests.tearDown()
            tests.tearDownClass()

    fn.__name__ = test_name
    import copy

    fn.__dict__ = copy.deepcopy(func.__dict__)
    if condition:
        setattr(
            CudaWrapperTemplate,
            test_name,
            fn,
        )


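# Register the CUDA cpp-wrapper tests only when a CUDA device test base is
# requested (see RUN_CUDA above).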
if RUN_CUDA:

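    # Each entry names an existing test to rerun under cpp_wrapper; device
    # defaults to "cuda" and tests defaults to the main GPU test suite.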
    class BaseTest(NamedTuple):
        name: str
        device: str = "cuda"
        tests: InductorTestCase = test_torchinductor.GPUTests()

    # Maintain two separate test lists for cuda and cpp for now
    for item in [
        BaseTest("test_add_complex"),
        BaseTest("test_add_complex4"),
        BaseTest("test_as_strided"),  # buffer reuse
        BaseTest("test_batch_norm_2d_2"),
        BaseTest("test_bernoulli1"),
        BaseTest("test_bitwise"),  # int32
        BaseTest("test_bmm1"),
        BaseTest("test_bmm2"),
        BaseTest("test_buffer_use_after_remove"),
        BaseTest("test_cat"),  # alias
        BaseTest("test_convolution1"),
        BaseTest("test_conv_backward"),
        BaseTest("test_custom_op_1"),
        BaseTest("test_custom_op_2"),
        BaseTest("test_custom_op_3"),
        BaseTest("test_embedding_bag"),  # test default FallbackKernel
        BaseTest("test_index_put_deterministic_fallback"),
        BaseTest("test_adding_tensor_offsets"),
        BaseTest("test_index_tensor"),
        BaseTest("test_inductor_layout_optimization_input_mutations"),
        BaseTest("test_insignificant_strides"),
        BaseTest("test_layer_norm"),
        BaseTest("test_linear1"),
        BaseTest("test_linear2"),
        BaseTest("test_mm_views"),
        BaseTest("test_multi_device"),
        BaseTest("test_multi_threading"),
        BaseTest("test_pow3"),
        BaseTest("test_profiler_mark_wrapper_call"),
        BaseTest("test_randint"),
        BaseTest("test_reduction1"),  # Reduction
        BaseTest("test_relu"),  # multiple inputs
        BaseTest("test_repeat_interleave_2"),
        BaseTest("test_roi_align"),
        BaseTest("test_scalar_input"),
        BaseTest("test_scaled_dot_product_attention"),
        BaseTest("test_scaled_dot_product_efficient_attention"),
        BaseTest("test_sort"),
        BaseTest("test_silu"),  # single input, single output
        BaseTest("test_sum_dtype"),  # float64
        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
        BaseTest("test_transpose"),  # multiple outputs, buffer clear
        BaseTest("test_unspec_inputs"),
        BaseTest("test_pointwise_hermite_polynomial_he"),
        BaseTest("test_pointwise_hermite_polynomial_h"),
        BaseTest(
            "test_foreach_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),  # test foreach
        BaseTest(
            "test_enable_dynamic_shapes_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),
        BaseTest(
            "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
            tests=test_combo_kernels.ComboKernelDynamicShapesTests(),
        ),
        BaseTest(
            "test_cat_slice_cat",
            tests=test_pattern_matcher.TestPatternMatcher(),
        ),
        # TODO: Re-enable this test after fixing the cuda wrapper for conv Triton templates with dynamic shapes.
        # This test is unstable: it passes when an ATen kernel is chosen and fails when a Triton kernel is chosen.
        # Currently it passes on CI (where an ATen kernel is chosen) and fails locally (where a Triton kernel is chosen).
        # Ideally, it should pass regardless of which kernel is chosen.
        # BaseTest(
        #     "test_convolution1",
        #     device=None,
        #     tests=test_select_algorithm.TestSelectAlgorithm(),
        # ),
        BaseTest(
            "test_mm_plus_mm2",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest(
            "test_mm_plus_mm3",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest("test_fft_real_input"),
        BaseTest("test_fft_real_input_real_output"),
        BaseTest("test_dtypeview"),
        BaseTest("test_dtypeview_fusion"),
    ]:
        make_test_case(item.name, item.device, item.tests)

    from torch._inductor.utils import is_big_gpu

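    # The select-algorithm template tests below only run on GPUs that
    # is_big_gpu(0) considers large enough for template autotuning.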
    if is_big_gpu(0):
        for item in [
            BaseTest(
                "test_addmm",
                tests=test_select_algorithm.TestSelectAlgorithm(),
            ),
            BaseTest(
                "test_linear_relu",
                tests=test_select_algorithm.TestSelectAlgorithm(),
            ),
        ]:
            make_test_case(item.name, item.device, item.tests)

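    # Copy every test from the template onto TestCudaWrapper under the
    # "cuda_wrapper" suffix, applying the known failures declared above.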
    test_torchinductor.copy_tests(
        CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper
    )

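    # Derive a dynamic-shapes variant of the template and copy it the same way;
    # tests marked with _expected_failure_dynamic_wrapper become expected failures.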
    DynamicShapesCudaWrapperTemplate = (
        test_torchinductor_dynamic_shapes.make_dynamic_cls(CudaWrapperTemplate)
    )

    test_torchinductor.copy_tests(
        DynamicShapesCudaWrapperTemplate,
        DynamicShapesCudaWrapperCudaTests,
        "cuda_wrapper",
        test_failures_cuda_wrapper,
        xfail_prop="_expected_failure_dynamic_wrapper",
    )

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    print(f"FS: run_cuda {RUN_CUDA}")
    if RUN_CUDA:
        run_tests(needs="filelock")
270