# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn
import torch.nn.functional as F

from torch.sparse import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)

from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

import pytest

from torch.utils._triton import has_triton

SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()

_IS_SM8X = False
_IS_SM9X = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

    # CUTLASS kernels only work for Ampere
    if _IS_SM8X:
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # add cuSPARSELt tests if available
    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 1e-3,
        "atol": 1e-3,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}

def sparse24_largest_mask_2d(original):
    sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
    return sparse.to_dense().bool()

def sparsify24_dense(original):
    return sparse24_largest_mask_2d(original) * original

def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )

def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [0, 1],
            [1, 0]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where the mask is applied.
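    # Zero out every position outside the chosen 1:2 / 2:4 pattern so the result
    # is exactly semi-structured sparse.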
    dense = dense.masked_fill(~mask, 0)
    return dense


def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [[0, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[1, 0], [1, 0]],
            [[1, 1], [1, 0]]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 1, 0]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 0, 1]],
            [[1, 0, 0, 0], [1, 0, 1, 0]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 0, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 1, 0, 0]],
            [[1, 1, 1, 1], [1, 1, 0, 0]],
        ]
    mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where the masks below are applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)

    return dense_inv, dense_val


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into `aten._sparse_semi_structured_addmm` + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse and sparse_compile have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)


    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)


    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

        def fn(x, e):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x, e)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x, e)
        output.backward(output)

class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        if IS_WINDOWS:
            self.skipTest("torch.compile not supported on windows")

    @inference_dtypes
    @parametrize_backends
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw an error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            torch.mm(A_sparse.t(), B)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize_backends
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test nn.Linear has the same numerics
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        input = torch.rand((dense_input_shape), device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize_backends
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize_backends
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
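        # 2:4 compression keeps two out of every four elements, so a 128x128
        # matrix is stored as a 128x64 values tensor plus packed metadata.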
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize_backends
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @inference_dtypes
    @parametrize_backends
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize_backends
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize_backends
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)


def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask

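# Illustrative sketch (not exercised directly by the tests below), assuming a CUDA
# device with a supported backend: the typical 2:4 pruning flow the training tests
# build on.
#
#     dense = torch.randn(128, 128, device="cuda", dtype=torch.float16)
#     mask = sparse24_largest_mask_2d(dense)            # boolean 2:4 keep-mask
#     pruned = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(dense)
#     out = pruned @ torch.randn(128, 64, device="cuda", dtype=torch.float16)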

class TestSparseSemiStructuredTraining(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)")

        if IS_WINDOWS:
            self.skipTest('CUTLASS not supported on windows')


    @training_dtypes
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different.
        # Instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
        dense = torch.randn(128, 128, device="cuda", dtype=dtype)
        pruned = _sparse_semi_structured_tile(dense)

        # CUTLASS
        reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cutlass.to_dense())

        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
        meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
        compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
        compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
                                                                             reference_cutlass.compressed_swizzled_bitmask.stride())
        cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
                                                    packed_cutlass,
                                                    meta_cutlass,
                                                    packed_t_cutlass,
                                                    meta_t_cutlass,
                                                    compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())

        # CUSPARSELT
        reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                            algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())

        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
                                                          packed_cusparselt,
                                                          None,
                                                          packed_t_cusparselt,
                                                          None,
                                                          compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())


    @training_dtypes
    @parametrize_backends
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
            device="cuda",
            dtype=dtype,
        )
        inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
        sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")

        mask = sInp.to_dense() / inp
        assert mask[:4, :4].int().tolist() == [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]

    @training_dtypes
    def test_gemm(self, dtype) -> None:
        M, N, K = 32, 32, 64
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)

        a.masked_fill_(~mask, 0)

        a_sparse = to_sparse_semi_structured(a)

        masked_a = a * mask
        ref_out = masked_a @ b
        sp24_out = a_sparse @ b
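        # The sparse matmul should match the masked dense reference within the
        # per-dtype tolerances defined in atol_rtol_kw.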
        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])


    @training_dtypes
    @parametrize_backends
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct a to make sure we always have exactly 8 elements per 4x4 tile
        a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
        a = a.repeat(M // 8, N // 8)
        assert a.shape == (M, N)
        a = a.cuda().to(dtype)
        b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)

        a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)

        mask_dense = sparse24_largest_mask_2d(a).to(dtype)

        if backend == "cutlass":
            assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
            (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
                mask_dense, use_cutlass=True)

            sparse_mask = SparseSemiStructuredTensorCUTLASS(
                mask_dense.shape,
                packed=packed,
                meta=meta,
                packed_t=packed_t,
                meta_t=meta_t,
                compressed_swizzled_bitmask=bitmask,
            )
            torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)

        ref_gemm = (mask_dense * a) @ b
        pack_gemm = a_sparse @ b
        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
        a = torch.randn([N, N], dtype=dtype, device="cuda")
        b = torch.eye(N, dtype=dtype, device="cuda")

        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Heuristic to ensure we pack the same values
        torch.testing.assert_close(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )

        mask_dense = sparse24_largest_mask_2d(a.to(dtype))

        ref_gemm = mask_dense * a
        # Test A@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
        )
        # Test A.t@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
        max_diff = (ref_gemm - pack_gemm).abs().argmax()

        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
        )

    @training_dtypes
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. Let's see how the kernel handles this.
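        # The checks below verify that both packed and packed_t encode the unused
        # slot as a literal 0 in the second position of the first row/column.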
        quad = torch.tensor(
            [
                [2, -1, -2, -3],  # Should be packed as `2 <null>`
                [-1, 8, -1, 6],
                [-1, -1, 4, 5],
                [-1, 3, 7, -1],
            ],
            dtype=dtype,
            device="cuda",
        )
        a = torch.randn([32, 64], dtype=dtype, device="cuda")
        a[:4, :4] = quad
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Check first line in A
        assert packed[0, 0].item() == 2
        assert packed[0, 1].item() == 0
        # And first column in A.t
        assert packed_t[0, 0].item() == 2
        assert packed_t[0, 1].item() == 0

    @training_dtypes
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)
        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        torch.testing.assert_close(packed, packed2)
        torch.testing.assert_close(packed_t, packed_t2)

    @training_dtypes
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)

        expected = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        ).to_dense()

        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        sparse = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed2,
            meta=meta,
            packed_t=packed_t2,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        dense = torch._sparse_semi_structured_apply_dense(x, bitmask)

        torch.testing.assert_close(dense, expected)
        torch.testing.assert_close(sparse.to_dense(), expected)


    @training_dtypes
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        a_m = sparse24_largest_mask_2d(a)
        b_m = sparse24_largest_mask_2d(b)
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
        a_s = SparseSemiStructuredTensorCUTLASS(
            a.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
        b_s = SparseSemiStructuredTensorCUTLASS(
            b.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(
            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
        )
        torch.testing.assert_close(
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

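    # Batched matmul against a 3-D dense operand is likewise unsupported and is
    # expected to raise NotImplementedError.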
    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

class TestSparseSemiStructuredCUTLASS(TestCase):
    """
    This contains CUTLASS specific tests for
        - torch._sparse_semi_structured_linear
    """
    def setUp(self):
        if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('CUTLASS not enabled')

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @inference_dtypes
    def test_linear_cutlass(self, device, dtype):

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            weight = rand_sparse_semi_structured(m, k, dtype, device)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
                                                           out_dtype=dtype_out if dtype == torch.int8 else None)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same so the results match.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @inference_dtypes
    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")

        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
            # mat2 is transposed since the int8 case supports only the row-major/column-major combination
            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None

            if use_input:
                if dtype.is_floating_point:
                    alpha = 1.3
                    beta = -0.7
                else:
                    alpha = 2
                    beta = -3

            dtype_dense = torch.float32
            mat1_dense = mat1.to(dtype_dense)
            mat2_dense = mat2.to(dtype_dense)
            if not use_input:
                output0 = torch.mm(mat1_dense, mat2_dense)
            else:
                input_dense = input.to(dtype_dense)[:, None]
                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)

            compressed = to_sparse_semi_structured(mat1)

            mat1_sparse = compressed.values()
            mat1_meta = compressed.indices()

            if not use_input:
                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
            else:
                output1 = torch._sparse_semi_structured_addmm(
                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
                )
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same so the results match.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for m, n, k, use_input in \
                itertools.product(range(3), range(3), range(3), (False, True)):
            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to convert the given dense matrix into the
            # pair of corresponding sparse and metadata matrices; the
            # latter is used here as a reference against which to
            # compare the metadata matrix produced by the
            # SparseSemiStructuredTensor class constructor.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


CUSPARSELT_NUM_ALG_IDS = 4
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]


class TestSparseSemiStructuredCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests for
        torch._cslt_compress
        torch._cslt_sparse_mm
    """
    def setUp(self):
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
    @training_dtypes
    def test_cslt_sparse_mm_alpha(self, dtype, device):
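        # alpha is a per-row scale vector applied to the sparse matmul output;
        # the dense reference below broadcasts it explicitly before comparing.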
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
        bias = torch.ones(128, device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, bias=bias)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
                              for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
        dense_result = dense_result.to(out_dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @inference_dtypes
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        # alg_id=3 not supported for float32 dtype
        if dtype == torch.float32 and alg_id == 3:
            return
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # For cuSPARSELt v0.4.0 there is a bug where, although there are 5 alg_ids, we run into an error
        # when using the last one (4).
        # In cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        # TODO: Move this into the cuSPARSELt backend
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
        assert torch.backends.cusparselt.is_available()

        # CUDA 11.8 has cuSPARSELt v0.4.0 support
        if version == (11, 8):
            assert torch.backends.cusparselt.version() == 400
        # CUDA 12.1 has cuSPARSELt v0.5.2 support
        elif version == (12, 1):
            assert torch.backends.cusparselt.version() == 502
        # CUDA 12.4+ has cuSPARSELt v0.6.2 support
        elif version >= (12, 4):
            assert torch.backends.cusparselt.version() == 602
        else:
            assert torch.backends.cusparselt.version() is None

if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0:
    instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
    instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()