# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn
import torch.nn.functional as F

from torch.sparse import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)

from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

import pytest

from torch.utils._triton import has_triton

SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()
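# Maps a backend name ("cutlass" / "cusparselt") to the SparseSemiStructuredTensor
# subclass implementing it; populated below based on what the current GPU and
# build support.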

_IS_SM8X = False
_IS_SM9X = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

    # CUTLASS kernels only work for Ampere
    if _IS_SM8X:
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # add cuSPARSELt tests if available
    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 1e-3,
        "atol": 1e-3,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}
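# The bfloat16 tolerances are intentionally much looser than the float16 ones:
# bfloat16 carries fewer mantissa bits, so the GEMMs checked below accumulate
# noticeably more rounding error.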

def sparse24_largest_mask_2d(original):
    sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
    return sparse.to_dense().bool()

def sparsify24_dense(original):
    return sparse24_largest_mask_2d(original) * original

def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means the matrix is also 2:4 and 4:8 sparse.
    """
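    # Illustrative example (hypothetical values): for r=2, c=4 the returned mask
    # could be
    #   [[0, 1, 1, 0],
    #    [1, 0, 0, 1]]
    # i.e. exactly one nonzero in every pair of adjacent elements.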

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )

def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
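    # float32 uses the 1:2 semi-structured pattern (one nonzero in every pair of
    # elements); float16/bfloat16/int8 use the usual 2:4 pattern.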
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [0, 1],
            [1, 0]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # Prevent zeros except where the mask is applied.
    dense = dense.masked_fill(~mask, 0)
    return dense


def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
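    """
    Return a pair (dense_inv, dense_val): dense_inv enumerates every possible
    mask pattern for a group (including patterns that keep more than the allowed
    number of nonzeros), while dense_val holds the values that remain once the
    extra elements are pruned.  Converting dense_inv to the semi-structured
    format is expected to yield dense_val (see test_conversions_all_patterns).
    """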
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [[0, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[1, 0], [1, 0]],
            [[1, 1], [1, 0]]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 1, 0]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 0, 1]],
            [[1, 0, 0, 0], [1, 0, 1, 0]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 0, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 1, 0, 0]],
            [[1, 1, 1, 1], [1, 1, 0, 0]],
        ]
    mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1   # Prevent zeros except where the masks below are applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)

    return dense_inv, dense_val


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into `aten._sparse_semi_structured_addmm` + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse and sparse_compile have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)


    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)


    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

        def fn(x, e):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x, e)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x, e)
        output.backward(output)

class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        if IS_WINDOWS:
            self.skipTest("torch.compile not supported on windows")

    @inference_dtypes
    @parametrize_backends
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and throws an error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # int8 matmul with a row-major dense operand is not supported on GPU,
        # so check that a descriptive error is raised
        if dtype is torch.int8:
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and throws an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            torch.mm(A_sparse.t(), B)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize_backends
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test that nn.Linear with a semi-structured sparse weight matches the dense numerics
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        input = torch.rand(dense_input_shape, device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize_backends
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize_backends
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize_backends
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @inference_dtypes
    @parametrize_backends
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize_backends
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize_backends
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)


def create_random_mask(shape) -> torch.Tensor:
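    """
    Build a deterministic boolean 2:4 mask: every group of four columns gets
    exactly two True entries, chosen with a fixed-seed RNG for reproducibility.
    """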
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask

class TestSparseSemiStructuredTraining(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)")

        if IS_WINDOWS:
            self.skipTest('CUTLASS not supported on windows')


    @training_dtypes
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different;
        # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
        dense = torch.randn(128, 128, device="cuda", dtype=dtype)
        pruned = _sparse_semi_structured_tile(dense)

        # CUTLASS
        reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cutlass.to_dense())

        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
        meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
        compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
        compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
                                                                             reference_cutlass.compressed_swizzled_bitmask.stride())
        cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
                                                    packed_cutlass,
                                                    meta_cutlass,
                                                    packed_t_cutlass,
                                                    meta_t_cutlass,
                                                    compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())

        # CUSPARSELT
        reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                            algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())

        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
                                                          packed_cusparselt,
                                                          None,
                                                          packed_t_cusparselt,
                                                          None,
                                                          compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())



    @training_dtypes
    @parametrize_backends
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
            device="cuda",
            dtype=dtype,
        )
        inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
        sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")

        mask = sInp.to_dense() / inp
        assert mask[:4, :4].int().tolist() == [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]
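        # The expected mask keeps two entries in every row and every column of the
        # 4x4 tile, which is consistent with pruning that keeps both the matrix and
        # its transpose 2:4 sparse.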

    @training_dtypes
    def test_gemm(self, dtype) -> None:
        M, N, K = 32, 32, 64
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)

        a.masked_fill_(~mask, 0)

        a_sparse = to_sparse_semi_structured(a)

        masked_a = a * mask
        ref_out = masked_a @ b
        sp24_out = a_sparse @ b
        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])


    @training_dtypes
    @parametrize_backends
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct a so that we always have exactly 8 elements per 4x4 tile
        a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
        a = a.repeat(M // 8, N // 8)
        assert a.shape == (M, N)
        a = a.cuda().to(dtype)
        b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)

        a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)

        mask_dense = sparse24_largest_mask_2d(a).to(dtype)

        if backend == "cutlass":
            assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
            (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
                mask_dense, use_cutlass=True)

            sparse_mask = SparseSemiStructuredTensorCUTLASS(
                mask_dense.shape,
                packed=packed,
                meta=meta,
                packed_t=packed_t,
                meta_t=meta_t,
                compressed_swizzled_bitmask=bitmask,
            )
            torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)

        ref_gemm = (mask_dense * a) @ b
        pack_gemm = a_sparse @ b
        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
        a = torch.randn([N, N], dtype=dtype, device="cuda")
        b = torch.eye(N, dtype=dtype, device="cuda")

        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[
            :4
        ]
        # Heuristic to ensure we pack the same values
        torch.testing.assert_close(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )

        mask_dense = sparse24_largest_mask_2d(a.to(dtype))

        ref_gemm = mask_dense * a
        # Test A@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )
        # Test A.t@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
        max_diff = (ref_gemm - pack_gemm).abs().argmax()

        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )

    @training_dtypes
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. Let's see how the kernel handles this.
        quad = torch.tensor(
            [
                [2, -1, -2, -3],  # Should be packed as `2 <null>`
                [-1, 8, -1, 6],
                [-1, -1, 4, 5],
                [-1, 3, 7, -1],
            ],
            dtype=dtype,
            device="cuda",
        )
        a = torch.randn([32, 64], dtype=dtype, device="cuda")
        a[:4, :4] = quad
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Check first line in A
        assert packed[0, 0].item() == 2
        assert packed[0, 1].item() == 0
        # And first column in A.t
        assert packed_t[0, 0].item() == 2
        assert packed_t[0, 1].item() == 0

    @training_dtypes
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)
        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        torch.testing.assert_close(packed, packed2)
        torch.testing.assert_close(packed_t, packed_t2)

    @training_dtypes
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)

        expected = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        ).to_dense()

        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        sparse = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed2,
            meta=meta,
            packed_t=packed_t2,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        dense = torch._sparse_semi_structured_apply_dense(x, bitmask)

        torch.testing.assert_close(dense, expected)
        torch.testing.assert_close(sparse.to_dense(), expected)


    @training_dtypes
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        a_m = sparse24_largest_mask_2d(a)
        b_m = sparse24_largest_mask_2d(b)
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
        a_s = SparseSemiStructuredTensorCUTLASS(
            a.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
        b_s = SparseSemiStructuredTensorCUTLASS(
            b.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(
            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
        )
        torch.testing.assert_close(
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])


    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

class TestSparseSemiStructuredCUTLASS(TestCase):
    """
    This contains CUTLASS specific tests for
         - torch._sparse_semi_structured_linear
    """
    def setUp(self):
        if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('CUTLASS not enabled')

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @inference_dtypes
    def test_linear_cutlass(self, device, dtype):

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            weight = rand_sparse_semi_structured(m, k, dtype, device)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
                                                           out_dtype=dtype_out if dtype == torch.int8 else None)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same so that the results match.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @inference_dtypes
    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")

        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
            # mat2 is transposed, as the int8 case supports only the row-major/column-major combination
            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None

            if use_input:
                if dtype.is_floating_point:
                    alpha = 1.3
                    beta = -0.7
                else:
                    alpha = 2
                    beta = -3

            dtype_dense = torch.float32
            mat1_dense = mat1.to(dtype_dense)
            mat2_dense = mat2.to(dtype_dense)
            if not use_input:
                output0 = torch.mm(mat1_dense, mat2_dense)
            else:
                input_dense = input.to(dtype_dense)[:, None]
                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)

            compressed = to_sparse_semi_structured(mat1)

            mat1_sparse = compressed.values()
            mat1_meta = compressed.indices()

            if not use_input:
                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
            else:
                output1 = torch._sparse_semi_structured_addmm(
                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
                )
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same so that the results match.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for m, n, k, use_input in \
                itertools.product(range(3), range(3), range(3), (False, True)):
            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to convert the given dense matrix into the
            # pair of corresponding sparse and metadata matrices; the
            # latter is used here as a reference against which to
            # compare the metadata matrix produced by the conversion
            # performed by the SparseSemiStructuredTensor class
            # constructor.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)



CUSPARSELT_NUM_ALG_IDS = 4
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
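# CUSPARSELT_NUM_ALG_IDS is the number of cuSPARSELt algorithm ids the tests
# below iterate over; CUSPARSELT_MIXED_DTYPE_SUPPORT lists the out_dtypes the
# mixed-dtype tests request for int8 inputs.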


class TestSparseSemiStructuredCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests for
        torch._cslt_compress
        torch._cslt_sparse_mm
    """
    def setUp(self):
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
    @training_dtypes
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
        bias = torch.ones(128, device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, bias=bias)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
                              for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
        dense_result = dense_result.to(out_dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @inference_dtypes
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        # alg_id=3 is not supported for the float32 dtype
        if dtype == torch.float32 and alg_id == 3:
            return
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # For cuSPARSELt v0.4.0 there is a bug: although there are 5 alg_ids, we run into an error
        # when using the last one (4).
        # In cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        # TODO: Move this into the cuSPARSELt backend
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
        assert torch.backends.cusparselt.is_available()

        # CUDA 11.8 has cuSPARSELt v0.4.0 support
        if version == (11, 8):
            assert torch.backends.cusparselt.version() == 400
        # CUDA 12.1 has cuSPARSELt v0.5.2 support
        elif version == (12, 1):
            assert torch.backends.cusparselt.version() == 502
        # CUDA 12.4+ has cuSPARSELt v0.6.2 support
        elif version >= (12, 4):
            assert torch.backends.cusparselt.version() == 602
        else:
            assert torch.backends.cusparselt.version() is None

if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0:
    instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
    instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()