# Owner(s): ["module: inductor"]
import sys
import unittest
from typing import NamedTuple

import torch
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import (
    get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CUDA


try:
    try:
        from . import (
            test_combo_kernels,
            test_foreach,
            test_pattern_matcher,
            test_select_algorithm,
            test_torchinductor,
            test_torchinductor_dynamic_shapes,
        )
    except ImportError:
        import test_combo_kernels

        import test_foreach
        import test_pattern_matcher
        import test_select_algorithm
        import test_torchinductor
        import test_torchinductor_dynamic_shapes
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


_desired_test_bases = get_desired_device_type_test_bases()
RUN_CUDA = (
    HAS_CUDA
    and any(getattr(x, "device_type", "") == "cuda" for x in _desired_test_bases)
    and not TEST_WITH_ASAN
)


class CudaWrapperTemplate:
    pass


class TestCudaWrapper(InductorTestCase):
    device = "cuda"


class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
    device = "cuda"


test_failures_cuda_wrapper = {
    "test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
        ("cuda_wrapper",), is_skip=True
    ),
}


if config.abi_compatible:
    xfail_list = []
    for test_name in xfail_list:
        test_failures_cuda_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cuda_wrapper",), is_skip=False
        )
        test_failures_cuda_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=False)
    skip_list = []
    for test_name in skip_list:
        test_failures_cuda_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cuda_wrapper",), is_skip=True
        )
        test_failures_cuda_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)


def make_test_case(
    name,
    device,
    tests,
    condition=True,
    slow=False,
    func_inputs=None,
    code_string_count=None,
):
    test_name = f"{name}_{device}" if device else name
    if code_string_count is None:
        code_string_count = {}

    func = getattr(tests, test_name)
    assert callable(func), "not a callable"
    func = slowTest(func) if slow else func

    @config.patch(cpp_wrapper=True)
    def fn(self):
        tests.setUpClass()
        tests.setUp()
        try:
            with torch._C._PreserveDispatchKeyGuard():
                torch._C._dispatch_tls_set_dispatch_key_included(
                    torch._C.DispatchKey.Dense, True
                )

                _, code = test_torchinductor.run_and_get_cpp_code(
                    func, *func_inputs if func_inputs else []
                )
                self.assertEqual("CppWrapperCodeCache" in code, True)
                self.assertTrue(
                    all(
                        code.count(string) == code_string_count[string]
                        for string in code_string_count
                    )
                )
        finally:
            tests.tearDown()
            tests.tearDownClass()

    fn.__name__ = test_name
    import copy

    fn.__dict__ = copy.deepcopy(func.__dict__)
    if condition:
        setattr(
            CudaWrapperTemplate,
            test_name,
            fn,
        )


if RUN_CUDA:

    class BaseTest(NamedTuple):
        name: str
        device: str = "cuda"
        tests: InductorTestCase = test_torchinductor.GPUTests()

    # Maintain two separate test lists for cuda and cpp for now
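    # Each BaseTest entry below names an existing test method on the given
    # `tests` instance (test_torchinductor.GPUTests by default). make_test_case
    # wraps that method with config.patch(cpp_wrapper=True), reruns it, and
    # asserts that the captured generated code went through CppWrapperCodeCache.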
    for item in [
        BaseTest("test_add_complex"),
        BaseTest("test_add_complex4"),
        BaseTest("test_as_strided"),  # buffer reuse
        BaseTest("test_batch_norm_2d_2"),
        BaseTest("test_bernoulli1"),
        BaseTest("test_bitwise"),  # int32
        BaseTest("test_bmm1"),
        BaseTest("test_bmm2"),
        BaseTest("test_buffer_use_after_remove"),
        BaseTest("test_cat"),  # alias
        BaseTest("test_convolution1"),
        BaseTest("test_conv_backward"),
        BaseTest("test_custom_op_1"),
        BaseTest("test_custom_op_2"),
        BaseTest("test_custom_op_3"),
        BaseTest("test_embedding_bag"),  # test default FallbackKernel
        BaseTest("test_index_put_deterministic_fallback"),
        BaseTest("test_adding_tensor_offsets"),
        BaseTest("test_index_tensor"),
        BaseTest("test_inductor_layout_optimization_input_mutations"),
        BaseTest("test_insignificant_strides"),
        BaseTest("test_layer_norm"),
        BaseTest("test_linear1"),
        BaseTest("test_linear2"),
        BaseTest("test_mm_views"),
        BaseTest("test_multi_device"),
        BaseTest("test_multi_threading"),
        BaseTest("test_pow3"),
        BaseTest("test_profiler_mark_wrapper_call"),
        BaseTest("test_randint"),
        BaseTest("test_reduction1"),  # Reduction
        BaseTest("test_relu"),  # multiple inputs
        BaseTest("test_repeat_interleave_2"),
        BaseTest("test_roi_align"),
        BaseTest("test_scalar_input"),
        BaseTest("test_scaled_dot_product_attention"),
        BaseTest("test_scaled_dot_product_efficient_attention"),
        BaseTest("test_sort"),
        BaseTest("test_silu"),  # single input, single output
        BaseTest("test_sum_dtype"),  # float64
        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
        BaseTest("test_transpose"),  # multiple outputs, buffer clear
        BaseTest("test_unspec_inputs"),
        BaseTest("test_pointwise_hermite_polynomial_he"),
        BaseTest("test_pointwise_hermite_polynomial_h"),
        BaseTest(
            "test_foreach_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),  # test foreach
        BaseTest(
            "test_enable_dynamic_shapes_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),
        BaseTest(
            "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
            tests=test_combo_kernels.ComboKernelDynamicShapesTests(),
        ),
        BaseTest(
            "test_cat_slice_cat",
            tests=test_pattern_matcher.TestPatternMatcher(),
        ),
        # TODO: Re-enable this test after fixing cuda wrapper for conv Triton templates with dynamic shapes.
        # This test is unstable: it succeeds when an ATen kernel is used, and fails when a Triton kernel is used.
        # Currently it passes on CI (an ATen kernel is chosen) and fails locally (a Triton kernel is chosen).
        # Ideally, it should succeed regardless of which kernel is chosen.
        # BaseTest(
        #     "test_convolution1",
        #     device=None,
        #     tests=test_select_algorithm.TestSelectAlgorithm(),
        # ),
        BaseTest(
            "test_mm_plus_mm2",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest(
            "test_mm_plus_mm3",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest("test_fft_real_input"),
        BaseTest("test_fft_real_input_real_output"),
        BaseTest("test_dtypeview"),
        BaseTest("test_dtypeview_fusion"),
    ]:
        make_test_case(item.name, item.device, item.tests)

    from torch._inductor.utils import is_big_gpu

    if is_big_gpu(0):
        for item in [
            BaseTest(
                "test_addmm",
                tests=test_select_algorithm.TestSelectAlgorithm(),
            ),
            BaseTest(
                "test_linear_relu",
                tests=test_select_algorithm.TestSelectAlgorithm(),
            ),
        ]:
            make_test_case(item.name, item.device, item.tests)

    test_torchinductor.copy_tests(
        CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper
    )

    DynamicShapesCudaWrapperTemplate = (
        test_torchinductor_dynamic_shapes.make_dynamic_cls(CudaWrapperTemplate)
    )

    test_torchinductor.copy_tests(
        DynamicShapesCudaWrapperTemplate,
        DynamicShapesCudaWrapperCudaTests,
        "cuda_wrapper",
        test_failures_cuda_wrapper,
        xfail_prop="_expected_failure_dynamic_wrapper",
    )

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    print(f"FS: run_cuda {RUN_CUDA}")
    if RUN_CUDA:
        run_tests(needs="filelock")
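
# Usage note (a rough sketch; the file name and generated test names below are
# illustrative, not guaranteed): the wrapper tests are attached dynamically to
# TestCudaWrapper and DynamicShapesCudaWrapperCudaTests only when RUN_CUDA is
# true, so a single generated test can typically be selected with a -k filter,
# e.g. something like:
#   python this_test_file.py -k test_relu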